From 57f2a51f9ae2c7c9951f137f3d247cff47100944 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期一, 27 二月 2023 16:55:06 +0800
Subject: [PATCH] onnx supports tiny and bicif paraformer
---
funasr/export/models/modules/encoder_layer.py | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 54 insertions(+), 0 deletions(-)
diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py
index 800a4f7..622b109 100644
--- a/funasr/export/models/modules/encoder_layer.py
+++ b/funasr/export/models/modules/encoder_layer.py
@@ -34,4 +34,58 @@
return x, mask
+class EncoderLayerConformer(nn.Module):
+ def __init__(
+ self,
+ model,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = model.self_attn
+ self.feed_forward = model.feed_forward
+ self.feed_forward_macaron = model.feed_forward_macaron
+ self.conv_module = model.conv_module
+ self.norm_ff = model.norm_ff
+ self.norm_mha = model.norm_mha
+ self.norm_ff_macaron = model.norm_ff_macaron
+ self.norm_conv = model.norm_conv
+ self.norm_final = model.norm_final
+ self.size = model.size
+ def forward(self, x, mask):
+ if isinstance(x, tuple):
+ x, pos_emb = x[0], x[1]
+ else:
+ x, pos_emb = x, None
+
+ if self.feed_forward_macaron is not None:
+ residual = x
+ x = self.norm_ff_macaron(x)
+ x = residual + self.feed_forward_macaron(x)
+
+ residual = x
+ x = self.norm_mha(x)
+
+ x_q = x
+
+ if pos_emb is not None:
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+ else:
+ x_att = self.self_attn(x_q, x, x, mask)
+ x = residual + x_att
+
+ if self.conv_module is not None:
+ residual = x
+ x = self.norm_conv(x)
+ x = residual + self.conv_module(x)
+
+ residual = x
+ x = self.norm_ff(x)
+ x = residual + self.feed_forward(x)
+
+ x = self.norm_final(x)
+
+ if pos_emb is not None:
+ return (x, pos_emb), mask
+
+ return x, mask
--
Gitblit v1.9.1