From d14bb0f843aea0aeb254a5b3c21d42d04e28765b Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期五, 14 七月 2023 17:43:40 +0800
Subject: [PATCH] add conformer export

---
 funasr/export/models/e2e_asr_conformer.py |   85 +++++++++++++-----------------------------
 1 files changed, 26 insertions(+), 59 deletions(-)

diff --git a/funasr/export/models/e2e_asr_conformer.py b/funasr/export/models/e2e_asr_conformer.py
index 69907fb..45feda5 100644
--- a/funasr/export/models/e2e_asr_conformer.py
+++ b/funasr/export/models/e2e_asr_conformer.py
@@ -1,3 +1,4 @@
+import os
 import logging
 import torch
 import torch.nn as nn
@@ -5,6 +6,7 @@
 from funasr.export.utils.torch_function import MakePadMask
 from funasr.export.utils.torch_function import sequence_mask
 from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
 from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
 from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
 
@@ -18,7 +20,6 @@
             model,
             max_seq_len=512,
             feats_dim=560,
-            output_size=2048,
             model_name='model',
             **kwargs,
     ):
@@ -32,71 +33,37 @@
             self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
         
         self.feats_dim = feats_dim
-        self.output_size = output_size
         self.model_name = model_name
 
         if onnx:
             self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
         else:
             self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
-        
-    def forward(
-            self,
-            speech: torch.Tensor,
-            speech_lengths: torch.Tensor,
-    ):
-        # a. To device
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
-        # batch = to_device(batch, device=self.device)
-    
-        enc, enc_len = self.encoder(**batch)
-        mask = self.make_pad_mask(enc_len)[:, None, :]
 
-        # fill the decoder input
-        enc_size = self.encoder.output_size
-        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
-        cache_num = len(self.model.decoder)
-        cache = [
-            torch.zeros((1, self.decoder.size, self.decoder.self_attn.kernel_size))
-            for _ in range(cache_num)
-        ]
+    def _export_model(self, model, verbose, path):
+        dummy_input = model.get_dummy_inputs()
+        model_script = model
+        model_path = os.path.join(path, f'{model.model_name}.onnx')
+        if not os.path.exists(model_path):
+            torch.onnx.export(
+                model_script,
+                dummy_input,
+                model_path,
+                verbose=verbose,
+                opset_version=14,
+                input_names=model.get_input_names(),
+                output_names=model.get_output_names(),
+                dynamic_axes=model.get_dynamic_axes()
+            )
 
-        decoder_out, olens = self.decoder(enc, enc_len, pre_acoustic_embeds, cache)
-        decoder_out = torch.log_softmax(decoder_out, dim=-1)
-        # sample_ids = decoder_out.argmax(dim=-1)
+    def _export_encoder_onnx(self, verbose, path):
+        model_encoder = self.encoder
+        self._export_model(model_encoder, verbose, path)
 
-        return decoder_out, olens
+    def _export_decoder_onnx(self, verbose, path):
+        model_decoder = self.decoder
+        self._export_model(model_decoder, verbose, path)
 
-    def get_dummy_inputs(self):
-        speech = torch.randn(2, 30, self.feats_dim)
-        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
-        return (speech, speech_lengths)
-
-    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
-        import numpy as np
-        fbank = np.loadtxt(txt_file)
-        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
-        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
-        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
-        return (speech, speech_lengths)
-
-    def get_input_names(self):
-        return ['speech', 'speech_lengths']
-
-    def get_output_names(self):
-        return ['logits', 'token_num']
-
-    def get_dynamic_axes(self):
-        return {
-            'speech': {
-                0: 'batch_size',
-                1: 'feats_length'
-            },
-            'speech_lengths': {
-                0: 'batch_size',
-            },
-            'logits': {
-                0: 'batch_size',
-                1: 'logits_length'
-            },
-        }
+    def _export_onnx(self, verbose, path):
+        self._export_encoder_onnx(verbose, path)
+        self._export_decoder_onnx(verbose, path)
\ No newline at end of file

--
Gitblit v1.9.1