From 95cf2646fa6dae67bf53354f4ed5e81780d8fee9 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 11 三月 2024 14:43:08 +0800
Subject: [PATCH] onnx (#1460)
---
funasr/models/transformer/decoder.py | 29 +++++++++++++++++++++++++++++
1 files changed, 29 insertions(+), 0 deletions(-)
diff --git a/funasr/models/transformer/decoder.py b/funasr/models/transformer/decoder.py
index 820de4a..1e88a25 100644
--- a/funasr/models/transformer/decoder.py
+++ b/funasr/models/transformer/decoder.py
@@ -150,6 +150,35 @@
return x, tgt_mask, memory, memory_mask
+class DecoderLayerExport(nn.Module):
+ def __init__(self, model):
+ super().__init__()
+ self.self_attn = model.self_attn
+ self.src_attn = model.src_attn
+ self.feed_forward = model.feed_forward
+ self.norm1 = model.norm1
+ self.norm2 = model.norm2
+ self.norm3 = model.norm3
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+ residual = tgt
+ tgt = self.norm1(tgt)
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ x = residual + self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)
+
+ residual = x
+ x = self.norm2(x)
+
+ x = residual + self.src_attn(x, memory, memory, memory_mask)
+
+ residual = x
+ x = self.norm3(x)
+ x = residual + self.feed_forward(x)
+
+ return x, tgt_mask, memory, memory_mask
+
+
class BaseTransformerDecoder(nn.Module, BatchScorerInterface):
"""Base class of Transfomer decoder module.
--
Gitblit v1.9.1