From 57f2a51f9ae2c7c9951f137f3d247cff47100944 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 27 二月 2023 16:55:06 +0800
Subject: [PATCH] onnx supports tiny and bicif paraformer

---
 funasr/export/models/modules/encoder_layer.py |   54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 54 insertions(+), 0 deletions(-)

diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py
index 800a4f7..622b109 100644
--- a/funasr/export/models/modules/encoder_layer.py
+++ b/funasr/export/models/modules/encoder_layer.py
@@ -34,4 +34,58 @@
         return x, mask
 
 
+class EncoderLayerConformer(nn.Module):
+    def __init__(
+        self,
+        model,
+    ):
+        """Construct an EncoderLayer object."""
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.feed_forward = model.feed_forward
+        self.feed_forward_macaron = model.feed_forward_macaron
+        self.conv_module = model.conv_module
+        self.norm_ff = model.norm_ff
+        self.norm_mha = model.norm_mha
+        self.norm_ff_macaron = model.norm_ff_macaron
+        self.norm_conv = model.norm_conv
+        self.norm_final = model.norm_final
+        self.size = model.size
 
+    def forward(self, x, mask):
+        if isinstance(x, tuple):
+            x, pos_emb = x[0], x[1]
+        else:
+            x, pos_emb = x, None
+
+        if self.feed_forward_macaron is not None:
+            residual = x
+            x = self.norm_ff_macaron(x)
+            x = residual + self.feed_forward_macaron(x)
+
+        residual = x
+        x = self.norm_mha(x)
+
+        x_q = x
+
+        if pos_emb is not None:
+            x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+        else:
+            x_att = self.self_attn(x_q, x, x, mask)
+        x = residual + x_att
+
+        if self.conv_module is not None:
+            residual = x
+            x = self.norm_conv(x)
+            x = residual +  self.conv_module(x)
+
+        residual = x
+        x = self.norm_ff(x)
+        x = residual + self.feed_forward(x)
+
+        x = self.norm_final(x)
+
+        if pos_emb is not None:
+            return (x, pos_emb), mask
+
+        return x, mask

--
Gitblit v1.9.1