shixian.shi
2023-02-27 57f2a51f9ae2c7c9951f137f3d247cff47100944
funasr/export/models/modules/decoder_layer.py
@@ -41,3 +41,30 @@
        return x, tgt_mask, memory, memory_mask, cache
class DecoderLayer(nn.Module):
    def __init__(self, model):
        super().__init__()
        self.self_attn = model.self_attn
        self.src_attn = model.src_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.norm3 = model.norm3
    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        residual = tgt
        tgt_q = tgt
        tgt_q_mask = tgt_mask
        x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
        residual = x
        x = self.norm2(x)
        x = residual + self.src_attn(x, memory, memory, memory_mask)
        residual = x
        x = self.norm3(x)
        x = residual + self.feed_forward(x)
        return x, tgt_mask, memory, memory_mask