From 9d48230c4f8f25bf88c5d6105f97370a36c9cf43 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 三月 2024 10:48:50 +0800
Subject: [PATCH] export onnx (#1457)

---
 funasr/models/bicif_paraformer/model.py |   88 +++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 87 insertions(+), 1 deletions(-)

diff --git a/funasr/models/bicif_paraformer/model.py b/funasr/models/bicif_paraformer/model.py
index 696cd56..eb7318b 100644
--- a/funasr/models/bicif_paraformer/model.py
+++ b/funasr/models/bicif_paraformer/model.py
@@ -341,4 +341,90 @@
                     result_i = {"key": key[i], "token_int": token_int}
                 results.append(result_i)
         
-        return results, meta_data
\ No newline at end of file
+        return results, meta_data
+
+    def export(
+        self,
+        max_seq_len=512,
+        **kwargs,
+    ):
+        is_onnx = kwargs.get("type", "onnx") == "onnx"
+        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
+        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
+    
+        predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
+        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
+    
+        decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
+        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
+    
+        from funasr.utils.torch_function import MakePadMask
+        from funasr.utils.torch_function import sequence_mask
+    
+        if is_onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+    
+        self.forward = self._export_forward
+    
+        return self
+
+    def _export_forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+    ):
+        # a. To device
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        # batch = to_device(batch, device=self.device)
+    
+        enc, enc_len = self.encoder(**batch)
+        mask = self.make_pad_mask(enc_len)[:, None, :]
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
+        pre_token_length = pre_token_length.round().type(torch.int32)
+    
+        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
+        decoder_out = torch.log_softmax(decoder_out, dim=-1)
+    
+        # get predicted timestamps
+        us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
+    
+        return decoder_out, pre_token_length, us_alphas, us_cif_peak
+
+    def export_dummy_inputs(self):
+        speech = torch.randn(2, 30, 560)
+        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+        return (speech, speech_lengths)
+
+    def export_input_names(self):
+        return ['speech', 'speech_lengths']
+
+    def export_output_names(self):
+        return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
+
+    def export_dynamic_axes(self):
+        return {
+            'speech': {
+                0: 'batch_size',
+                1: 'feats_length'
+            },
+            'speech_lengths': {
+                0: 'batch_size',
+            },
+            'logits': {
+                0: 'batch_size',
+                1: 'logits_length'
+            },
+            'us_alphas': {
+                0: 'batch_size',
+                1: 'alphas_length'
+            },
+            'us_cif_peak': {
+                0: 'batch_size',
+                1: 'alphas_length'
+            },
+        }
+
+    def export_name(self, ):
+        return "model.onnx"

--
Gitblit v1.9.1