From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 27 Mar 2024 16:05:29 +0800
Subject: [PATCH] train update

---
 funasr/models/transformer/decoder.py |   41 +++++++++++++++++++++++++++++++++++------
 1 file changed, 35 insertions(+), 6 deletions(-)

diff --git a/funasr/models/transformer/decoder.py b/funasr/models/transformer/decoder.py
index 3e8d224..1e88a25 100644
--- a/funasr/models/transformer/decoder.py
+++ b/funasr/models/transformer/decoder.py
@@ -26,7 +26,7 @@
 from funasr.models.transformer.utils.repeat import repeat
 from funasr.models.transformer.scorers.scorer_interface import BatchScorerInterface
 
-from funasr.utils.register import register_class, registry_tables
+from funasr.register import tables
 
 class DecoderLayer(nn.Module):
     """Single decoder layer module.
@@ -147,6 +147,35 @@
         if cache is not None:
             x = torch.cat([cache, x], dim=1)
 
+        return x, tgt_mask, memory, memory_mask
+
+
+class DecoderLayerExport(nn.Module):
+    """Wrap ``model`` and reuse its attention, feed-forward and norm sub-modules.
+
+    NOTE(review): ``cache`` is accepted by ``forward`` but never read here.
+    """
+
+    def __init__(self, model):
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.src_attn = model.src_attn
+        self.feed_forward = model.feed_forward
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+        self.norm3 = model.norm3
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+        # norm1 -> self-attention (query/key/value all ``normed``), plus residual.
+        normed = self.norm1(tgt)
+        x = tgt + self.self_attn(normed, normed, normed, tgt_mask)
+
+        # norm2 -> source attention against ``memory``, plus residual.
+        x = x + self.src_attn(self.norm2(x), memory, memory, memory_mask)
+
+        # norm3 -> feed-forward, plus residual.
+        x = x + self.feed_forward(self.norm3(x))
+
         return x, tgt_mask, memory, memory_mask
 
 
@@ -352,7 +381,7 @@
         state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
         return logp, state_list
 
-@register_class("decoder_classes", "TransformerDecoder")
+@tables.register("decoder_classes", "TransformerDecoder")
 class TransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -401,7 +430,7 @@
         )
 
 
-@register_class("decoder_classes", "LightweightConvolutionTransformerDecoder")
+@tables.register("decoder_classes", "LightweightConvolutionTransformerDecoder")
 class LightweightConvolutionTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -462,7 +491,7 @@
             ),
         )
 
-@register_class("decoder_classes", "LightweightConvolution2DTransformerDecoder")
+@tables.register("decoder_classes", "LightweightConvolution2DTransformerDecoder")
 class LightweightConvolution2DTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -524,7 +553,7 @@
         )
 
 
-@register_class("decoder_classes", "DynamicConvolutionTransformerDecoder")
+@tables.register("decoder_classes", "DynamicConvolutionTransformerDecoder")
 class DynamicConvolutionTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,
@@ -585,7 +614,7 @@
             ),
         )
 
-@register_class("decoder_classes", "DynamicConvolution2DTransformerDecoder")
+@tables.register("decoder_classes", "DynamicConvolution2DTransformerDecoder")
 class DynamicConvolution2DTransformerDecoder(BaseTransformerDecoder):
     def __init__(
             self,

--
Gitblit v1.9.1