From 6850e9f86a2c478455907eb575748eb9c45cbddc Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: Tue, 11 Jul 2023 17:36:11 +0800
Subject: [PATCH] add conformer decoder export
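
Add XformerDecoder, an ONNX-exportable wrapper around the
transformer/conformer decoder, under funasr/export/models/decoder/.
It swaps the multi-head attention modules for their export-friendly
counterparts and is wired into e2e_asr_conformer.py, which previously
imported the non-export TransformerDecoder.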

---
 funasr/export/models/e2e_asr_conformer.py       |    3 +-
 funasr/export/models/decoder/xformer_decoder.py |  116 +++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 117 insertions(+), 2 deletions(-)

diff --git a/funasr/export/models/decoder/xformer_decoder.py b/funasr/export/models/decoder/xformer_decoder.py
new file mode 100644
index 0000000..29837e1
--- /dev/null
+++ b/funasr/export/models/decoder/xformer_decoder.py
@@ -0,0 +1,116 @@
+import os
+
+import torch
+import torch.nn as nn
+
+from funasr.modules.attention import MultiHeadedAttention
+
+from funasr.export.models.modules.decoder_layer import DecoderLayer as OnnxDecoderLayer
+from funasr.export.models.language_models.embed import Embedding
+from funasr.export.models.modules.multihead_att import \
+    OnnxMultiHeadedAttention
+
+from funasr.export.utils.torch_function import MakePadMask, subsequent_mask
+
+class XformerDecoder(nn.Module):
+    def __init__(self, model, max_seq_len=512, **kwargs):
+        super().__init__()
+        self.embed = Embedding(model.embed, max_seq_len)
+        self.model = model
+        self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+        if isinstance(self.model.decoders[0].self_attn, MultiHeadedAttention):
+            self.num_heads = self.model.decoders[0].self_attn.h
+            self.hidden_size = self.model.decoders[0].self_attn.linear_out.out_features
+
+        # Replace the multi-head attention modules with ONNX-exportable versions.
+        for i, d in enumerate(self.model.decoders):
+            # d is DecoderLayer
+            if isinstance(d.self_attn, MultiHeadedAttention):
+                d.self_attn = OnnxMultiHeadedAttention(d.self_attn)
+            if isinstance(d.src_attn, MultiHeadedAttention):
+                d.src_attn = OnnxMultiHeadedAttention(d.src_attn)
+            self.model.decoders[i] = OnnxDecoderLayer(d)
+
+        self.model_name = "xformer_decoder"
+
+    def prepare_mask(self, mask):
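+        """Convert a 1/0 mask into an additive attention bias.
+
+        A 2-D (B, T) padding mask or a 3-D (B, T, T) causal mask is
+        broadcast to 4-D; masked (0) positions receive a large
+        negative value so the softmax ignores them.
+        """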
+        if len(mask.shape) == 2:
+            mask = mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask = mask[:, None, :]
+        mask = 1 - mask
+        return mask * -10000.0
+
+    def forward(self, tgt, memory, cache):
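+        """Single decoding step: return log-probabilities for the
+        last position of tgt together with the updated layer caches.
+        """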
+        mask = subsequent_mask(tgt.size(-1)).unsqueeze(0)  # (1, T, T)
+
+        x = self.embed(tgt)
+        mask = self.prepare_mask(mask)
+        new_cache = []
+        for c, decoder in zip(cache, self.model.decoders):
+            x, mask = decoder(x, mask, memory, None, c)
+            # The layer output is kept as this layer's updated cache;
+            # the leading frame is dropped before the next layer.
+            new_cache.append(x)
+            x = x[:, 1:, :]
+
+        if self.model.normalize_before:
+            y = self.model.after_norm(x[:, -1])
+        else:
+            y = x[:, -1]
+
+        if self.model.output_layer is not None:
+            y = torch.log_softmax(self.model.output_layer(y), dim=-1)
+        return y, new_cache
+
+    def get_dummy_inputs(self, enc_size):
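+        # Dummy inputs for ONNX tracing: a single token id, a 100-frame
+        # encoder output, and a one-frame zero cache per decoder layer.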
+        tgt = torch.LongTensor([0]).unsqueeze(0)
+        enc_out = torch.randn(1, 100, enc_size)
+        cache = [
+            torch.zeros((1, 1, self.model.decoders[0].size))
+            for _ in range(len(self.model.decoders))
+        ]
+        return (tgt, enc_out, cache)
+
+    def is_optimizable(self):
+        return True
+
+    def get_input_names(self):
+        return ["tgt", "memory"] + [
+            "cache_%d" % i for i in range(len(self.model.decoders))
+        ]
+
+    def get_output_names(self):
+        return ["y"] + ["out_cache_%d" % i for i in range(len(self.model.decoders))]
+
+    def get_dynamic_axes(self):
+        ret = {
+            "tgt": {0: "tgt_batch", 1: "tgt_length"},
+            "memory": {0: "memory_batch", 1: "memory_length"},
+        }
+        ret.update(
+            {
+                "cache_%d" % d: {0: "cache_%d_batch" % d, 1: "cache_%d_length" % d}
+                for d in range(len(self.model.decoders))
+            }
+        )
+        return ret
+
+    def get_model_config(self, path):
+        return {
+            "dec_type": "XformerDecoder",
+            "model_path": os.path.join(path, f"{self.model_name}.onnx"),
+            "n_layers": len(self.model.decoders),
+            "odim": self.model.decoders[0].size,
+        }
diff --git a/funasr/export/models/e2e_asr_conformer.py b/funasr/export/models/e2e_asr_conformer.py
index 49c9aae..69907fb 100644
--- a/funasr/export/models/e2e_asr_conformer.py
+++ b/funasr/export/models/e2e_asr_conformer.py
@@ -6,8 +6,7 @@
 from funasr.export.utils.torch_function import sequence_mask
 from funasr.models.encoder.conformer_encoder import ConformerEncoder
 from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
-from funasr.models.decoder.transformer_decoder import TransformerDecoder as TransformerDecoder_export
-
+from funasr.export.models.decoder.xformer_decoder import XformerDecoder as TransformerDecoder_export
 
 class Conformer(nn.Module):
     """

--
Gitblit v1.9.1