From e04489ce4c0fd0095d0c79ef8f504f425e0435a8 Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: 星期三, 13 三月 2024 16:34:42 +0800
Subject: [PATCH] contextual&seaco ONNX export (#1481)

---
 funasr/models/paraformer/model.py |   85 ++++--------------------------------------
 1 files changed, 9 insertions(+), 76 deletions(-)

diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 41a1bf7..316255d 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -13,15 +13,16 @@
 from funasr.models.ctc.ctc import CTC
 from funasr.utils import postprocess_utils
 from funasr.metrics.compute_acc import th_accuracy
+from funasr.train_utils.device_funcs import to_device
 from funasr.utils.datadir_writer import DatadirWriter
 from funasr.models.paraformer.search import Hypothesis
 from funasr.models.paraformer.cif_predictor import mae_loss
 from funasr.train_utils.device_funcs import force_gatherable
 from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.train_utils.device_funcs import to_device
+
 
 @tables.register("model_classes", "Paraformer")
 class Paraformer(torch.nn.Module):
@@ -549,78 +550,10 @@
                 
         return results, meta_data
 
-    def export(
-        self,
-        max_seq_len=512,
-        **kwargs,
-    ):
-        self.device = kwargs.get("device")
-        is_onnx = kwargs.get("type", "onnx") == "onnx"
-        encoder_class = tables.encoder_classes.get(kwargs["encoder"]+"Export")
-        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
-        
-        predictor_class = tables.predictor_classes.get(kwargs["predictor"]+"Export")
-        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
+    def export(self, **kwargs):
+        from .export_meta import export_rebuild_model
+        if 'max_seq_len' not in kwargs:
+            kwargs['max_seq_len'] = 512
+        models = export_rebuild_model(model=self, **kwargs)
+        return models
 
-
-        decoder_class = tables.decoder_classes.get(kwargs["decoder"]+"Export")
-        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
-        
-        from funasr.utils.torch_function import sequence_mask
-
-
-        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-        
-        self.forward = self.export_forward
-        
-        return self
-
-    def export_forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-    ):
-        # a. To device
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
-        batch = to_device(batch, device=self.device)
-    
-        enc, enc_len = self.encoder(**batch)
-        mask = self.make_pad_mask(enc_len)[:, None, :]
-        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
-        pre_token_length = pre_token_length.floor().type(torch.int32)
-    
-        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
-        decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        # sample_ids = decoder_out.argmax(dim=-1)
-    
-        return decoder_out, pre_token_length
-
-    def export_dummy_inputs(self):
-        speech = torch.randn(2, 30, 560)
-        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
-        return (speech, speech_lengths)
-
-
-    def export_input_names(self):
-        return ['speech', 'speech_lengths']
-
-    def export_output_names(self):
-        return ['logits', 'token_num']
-
-    def export_dynamic_axes(self):
-        return {
-            'speech': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'speech_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-        }
-
-    def export_name(self, ):
-        return "model.onnx"

--
Gitblit v1.9.1