From 38de2af5bf9976d2f14f087d9a0d31991daf6783 Mon Sep 17 00:00:00 2001
From: Zhihao Du <neo.dzh@alibaba-inc.com>
Date: 星期四, 16 三月 2023 19:41:34 +0800
Subject: [PATCH] Merge branch 'main' into dev_dzh

---
 funasr/export/models/modules/encoder_layer.py |    6 +++---
 1 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py
index d132574..7d01397 100644
--- a/funasr/export/models/modules/encoder_layer.py
+++ b/funasr/export/models/modules/encoder_layer.py
@@ -16,6 +16,7 @@
         self.feed_forward = model.feed_forward
         self.norm1 = model.norm1
         self.norm2 = model.norm2
+        self.in_size = model.in_size
         self.size = model.size
 
     def forward(self, x, mask):
@@ -23,13 +24,12 @@
         residual = x
         x = self.norm1(x)
         x = self.self_attn(x, mask)
-        if x.size(2) == residual.size(2):
+        if self.in_size == self.size:
             x = x + residual
         residual = x
         x = self.norm2(x)
         x = self.feed_forward(x)
-        if x.size(2) == residual.size(2):
-            x = x + residual
+        x = x + residual
 
         return x, mask
 

--
Gitblit v1.9.1