From 69ccdd35cda4c8482e189fa350fbcb83997872f2 Mon Sep 17 00:00:00 2001
From: wanchen.swc <wanchen.swc@alibaba-inc.com>
Date: 星期一, 06 三月 2023 18:18:31 +0800
Subject: [PATCH] [Quantization] model quantization for inference
---
funasr/export/models/modules/encoder_layer.py | 6 +++---
1 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/funasr/export/models/modules/encoder_layer.py b/funasr/export/models/modules/encoder_layer.py
index 622b109..1da05f3 100644
--- a/funasr/export/models/modules/encoder_layer.py
+++ b/funasr/export/models/modules/encoder_layer.py
@@ -16,6 +16,7 @@
self.feed_forward = model.feed_forward
self.norm1 = model.norm1
self.norm2 = model.norm2
+ self.in_size = model.in_size
self.size = model.size
def forward(self, x, mask):
@@ -23,13 +24,12 @@
residual = x
x = self.norm1(x)
x = self.self_attn(x, mask)
- if x.size(2) == residual.size(2):
+ if self.in_size == self.size:
x = x + residual
residual = x
x = self.norm2(x)
x = self.feed_forward(x)
- if x.size(2) == residual.size(2):
- x = x + residual
+ x = x + residual
return x, mask
--
Gitblit v1.9.1