From b8bf792ce7df411ae4ed8d2bd8c8eba7c59e082b Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期三, 10 四月 2024 11:37:27 +0800
Subject: [PATCH] fix bug

---
 funasr/models/transformer/decoder.py |   29 +++++++++++++++++++++++++++++
 1 files changed, 29 insertions(+), 0 deletions(-)

diff --git a/funasr/models/transformer/decoder.py b/funasr/models/transformer/decoder.py
index 820de4a..1e88a25 100644
--- a/funasr/models/transformer/decoder.py
+++ b/funasr/models/transformer/decoder.py
@@ -150,6 +150,35 @@
         return x, tgt_mask, memory, memory_mask
 
 
+class DecoderLayerExport(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.src_attn = model.src_attn
+        self.feed_forward = model.feed_forward
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+        self.norm3 = model.norm3
+    
+    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+        residual = tgt
+        tgt = self.norm1(tgt)
+        tgt_q = tgt
+        tgt_q_mask = tgt_mask
+        x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
+        
+        residual = x
+        x = self.norm2(x)
+        
+        x = residual + self.src_attn(x, memory, memory, memory_mask)
+        
+        residual = x
+        x = self.norm3(x)
+        x = residual + self.feed_forward(x)
+        
+        return x, tgt_mask, memory, memory_mask
+
+
 class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
     """Base class of Transfomer decoder module.
 

--
Gitblit v1.9.1