From e04489ce4c0fd0095d0c79ef8f504f425e0435a8 Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: 星期三, 13 三月 2024 16:34:42 +0800
Subject: [PATCH] contextual&seaco ONNX export (#1481)
---
funasr/models/paraformer/decoder.py | 76 ++++++++++++++++++++++++++++++++++++--
1 files changed, 72 insertions(+), 4 deletions(-)
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index 7c370ba..f08e97b 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -616,6 +616,22 @@
return x, tgt_mask, memory, memory_mask, cache
+
+ def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ residual = tgt
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn is not None:
+ tgt = self.norm2(tgt)
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + x
+
+ residual = x
+ x = self.norm3(x)
+ x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
+ return attn_mat
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
@@ -675,6 +691,8 @@
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
+ return_hidden: bool = False,
+ return_both: bool = False,
):
tgt = ys_in_pad
@@ -698,11 +716,60 @@
x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
x, tgt_mask, memory, memory_mask
)
- x = self.after_norm(x)
- x = self.output_layer(x)
+ hidden = self.after_norm(x)
+ # x = self.output_layer(x)
- return x, ys_in_lens
+ if self.output_layer is not None and return_hidden is False:
+ x = self.output_layer(hidden)
+ return x, ys_in_lens
+ if return_both:
+ x = self.output_layer(hidden)
+ return x, hidden, ys_in_lens
+ return hidden, ys_in_lens
+ def forward_asf2(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ _, memory_mask = self.prepare_mask(memory_mask)
+
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
+ attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+ return attn_mat
+
+ def forward_asf6(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ _, memory_mask = self.prepare_mask(memory_mask)
+
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](tgt, tgt_mask, memory, memory_mask)
+ attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+ return attn_mat
+
+ '''
def get_dummy_inputs(self, enc_size):
tgt = torch.LongTensor([0]).unsqueeze(0)
memory = torch.randn(1, 100, enc_size)
@@ -751,7 +818,8 @@
for d in range(cache_num)
})
return ret
-
+ '''
+
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
def __init__(self, model,
--
Gitblit v1.9.1