From 6427c834dfd97b1f05c6659cdc7ccf010bf82fe1 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 24 四月 2023 19:50:07 +0800
Subject: [PATCH] update
---
funasr/export/models/modules/encoder_layer.py | 60 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++---
1 files changed, 57 insertions(+), 3 deletions(-)
diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py
index 800a4f7..7d01397 100644
--- a/funasr/export/models/modules/encoder_layer.py
+++ b/funasr/export/models/modules/encoder_layer.py
@@ -16,6 +16,7 @@
self.feed_forward = model.feed_forward
self.norm1 = model.norm1
self.norm2 = model.norm2
+ self.in_size = model.in_size
self.size = model.size
def forward(self, x, mask):
@@ -23,15 +24,68 @@
residual = x
x = self.norm1(x)
x = self.self_attn(x, mask)
- if x.size(2) == residual.size(2):
+ if self.in_size == self.size:
x = x + residual
residual = x
x = self.norm2(x)
x = self.feed_forward(x)
- if x.size(2) == residual.size(2):
- x = x + residual
+ x = x + residual
return x, mask
+class EncoderLayerConformer(nn.Module):
+ def __init__(
+ self,
+ model,
+ ):
+ """Construct an EncoderLayer object."""
+ super().__init__()
+ self.self_attn = model.self_attn
+ self.feed_forward = model.feed_forward
+ self.feed_forward_macaron = model.feed_forward_macaron
+ self.conv_module = model.conv_module
+ self.norm_ff = model.norm_ff
+ self.norm_mha = model.norm_mha
+ self.norm_ff_macaron = model.norm_ff_macaron
+ self.norm_conv = model.norm_conv
+ self.norm_final = model.norm_final
+ self.size = model.size
+ def forward(self, x, mask):
+ if isinstance(x, tuple):
+ x, pos_emb = x[0], x[1]
+ else:
+ x, pos_emb = x, None
+
+ if self.feed_forward_macaron is not None:
+ residual = x
+ x = self.norm_ff_macaron(x)
+ x = residual + self.feed_forward_macaron(x) * 0.5
+
+ residual = x
+ x = self.norm_mha(x)
+
+ x_q = x
+
+ if pos_emb is not None:
+ x_att = self.self_attn(x_q, x, x, pos_emb, mask)
+ else:
+ x_att = self.self_attn(x_q, x, x, mask)
+ x = residual + x_att
+
+ if self.conv_module is not None:
+ residual = x
+ x = self.norm_conv(x)
+ x = residual + self.conv_module(x)
+
+ residual = x
+ x = self.norm_ff(x)
+ x = residual + self.feed_forward(x) * 0.5
+
+ x = self.norm_final(x)
+
+ if pos_emb is not None:
+ return (x, pos_emb), mask
+
+ return x, mask
--
Gitblit v1.9.1