From 9b4e9cc8a0311e5243d69b73ed073e7ea441982e Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 27 三月 2024 16:05:29 +0800
Subject: [PATCH] train update
---
funasr/models/paraformer/decoder.py | 104 +++++++++++++++++++++++++++++++++++++++++----------
1 files changed, 83 insertions(+), 21 deletions(-)
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index 59c6e1d..b75d21d 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -351,9 +351,9 @@
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
- return_hidden: bool = False,
- return_both: bool= False,
chunk_mask: torch.Tensor = None,
+ return_hidden: bool = False,
+ return_both: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -616,6 +616,22 @@
return x, tgt_mask, memory, memory_mask, cache
+
+ def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ residual = tgt
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn is not None:
+ tgt = self.norm2(tgt)
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + x
+
+ residual = x
+ x = self.norm3(x)
+ x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
+ return attn_mat
@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
@@ -628,14 +644,12 @@
):
super().__init__()
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
- from funasr.utils.torch_function import MakePadMask
+
from funasr.utils.torch_function import sequence_mask
self.model = model
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
@@ -677,6 +691,8 @@
hlens: torch.Tensor,
ys_in_pad: torch.Tensor,
ys_in_lens: torch.Tensor,
+ return_hidden: bool = False,
+ return_both: bool = False,
):
tgt = ys_in_pad
@@ -700,11 +716,60 @@
x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
x, tgt_mask, memory, memory_mask
)
- x = self.after_norm(x)
- x = self.output_layer(x)
+ hidden = self.after_norm(x)
+ # x = self.output_layer(x)
- return x, ys_in_lens
+ if self.output_layer is not None and return_hidden is False:
+ x = self.output_layer(hidden)
+ return x, ys_in_lens
+ if return_both:
+ x = self.output_layer(hidden)
+ return x, hidden, ys_in_lens
+ return hidden, ys_in_lens
+ def forward_asf2(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ _, memory_mask = self.prepare_mask(memory_mask)
+
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
+ attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+ return attn_mat
+
+ def forward_asf6(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+ _, memory_mask = self.prepare_mask(memory_mask)
+
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[0](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[1](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[2](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[3](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.model.decoders[4](tgt, tgt_mask, memory, memory_mask)
+ attn_mat = self.model.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+ return attn_mat
+
+ '''
def get_dummy_inputs(self, enc_size):
tgt = torch.LongTensor([0]).unsqueeze(0)
memory = torch.randn(1, 100, enc_size)
@@ -753,7 +818,8 @@
for d in range(cache_num)
})
return ret
-
+ '''
+
@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
def __init__(self, model,
@@ -763,14 +829,12 @@
super().__init__()
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
self.model = model
- from funasr.utils.torch_function import MakePadMask
+
from funasr.utils.torch_function import sequence_mask
self.model = model
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
@@ -1036,14 +1100,12 @@
# self.embed = model.embed #Embedding(model.embed, max_seq_len)
self.model = model
- from funasr.utils.torch_function import MakePadMask
+
from funasr.utils.torch_function import sequence_mask
self.model = model
- if onnx:
- self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
- else:
- self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
from funasr.models.transformer.decoder import DecoderLayerExport
--
Gitblit v1.9.1