From c542eacb0aadcbc49c63db40429fca4e08f807a4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 21 七月 2023 10:27:35 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/export/models/e2e_asr_conformer.py |   69 ++++++++++++++++++++++++++++++++++
 1 files changed, 69 insertions(+), 0 deletions(-)

diff --git a/funasr/export/models/e2e_asr_conformer.py b/funasr/export/models/e2e_asr_conformer.py
new file mode 100644
index 0000000..45feda5
--- /dev/null
+++ b/funasr/export/models/e2e_asr_conformer.py
@@ -0,0 +1,69 @@
+import os
+import logging
+import torch
+import torch.nn as nn
+
+from funasr.export.utils.torch_function import MakePadMask
+from funasr.export.utils.torch_function import sequence_mask
+from funasr.models.encoder.conformer_encoder import ConformerEncoder
+from funasr.models.decoder.transformer_decoder import TransformerDecoder
+from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
+from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
+
+class Conformer(nn.Module):
+    """
+    export conformer into onnx format
+    """
+
+    def __init__(
+            self,
+            model,
+            max_seq_len=512,
+            feats_dim=560,
+            model_name='model',
+            **kwargs,
+    ):
+        super().__init__()
+        onnx = False
+        if "onnx" in kwargs:
+            onnx = kwargs["onnx"]
+        if isinstance(model.encoder, ConformerEncoder):
+            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
+        elif isinstance(model.decoder, TransformerDecoder):
+            self.decoder = TransformerDecoder_export(model.decoder, onnx=onnx)
+        
+        self.feats_dim = feats_dim
+        self.model_name = model_name
+
+        if onnx:
+            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        else:
+            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+    def _export_model(self, model, verbose, path):
+        dummy_input = model.get_dummy_inputs()
+        model_script = model
+        model_path = os.path.join(path, f'{model.model_name}.onnx')
+        if not os.path.exists(model_path):
+            torch.onnx.export(
+                model_script,
+                dummy_input,
+                model_path,
+                verbose=verbose,
+                opset_version=14,
+                input_names=model.get_input_names(),
+                output_names=model.get_output_names(),
+                dynamic_axes=model.get_dynamic_axes()
+            )
+
+    def _export_encoder_onnx(self, verbose, path):
+        model_encoder = self.encoder
+        self._export_model(model_encoder, verbose, path)
+
+    def _export_decoder_onnx(self, verbose, path):
+        model_decoder = self.decoder
+        self._export_model(model_decoder, verbose, path)
+
+    def _export_onnx(self, verbose, path):
+        self._export_encoder_onnx(verbose, path)
+        self._export_decoder_onnx(verbose, path)
\ No newline at end of file

--
Gitblit v1.9.1