From 4ba1011b42e041ee1d71448eefd7ef2e7bd61bb6 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 31 Mar 2023 15:31:26 +0800
Subject: [PATCH] export: adapt SANM attention mask handling and FSMN cache for export
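
funasr/modules/attention.py:

- MultiHeadedAttentionSANM.forward_fsmn: only multiply the output by the
  mask when one is supplied ("if mask is not None"), so the block also
  runs when no mask is passed, e.g. while tracing the model for export.

- Add MultiHeadedAttentionSANMwithMask, a thin subclass whose forward
  takes the masks as an indexable pair: mask[0] is used for the FSMN
  memory path (forward_fsmn) and mask[1] for the attention scores
  (forward_attention).

- MultiHeadedAttentionSANMDecoder: when decoding with a cache, keep the
  last kernel_size + t - 1 frames instead of only kernel_size, and seed
  the cache in eval mode regardless of chunk length (the old condition
  additionally required t <= 1).

The new slice length follows from the usual Conv1d arithmetic: with
kernel size k and no padding, an input of length L yields L - k + 1
outputs, so a chunk of t frames needs k + t - 1 input frames to produce
exactly t outputs. A standalone sketch of that arithmetic (plain
PyTorch; shapes and values are assumed for illustration, not code from
this patch):

    import torch
    import torch.nn as nn

    k, d, t = 11, 4, 5
    conv = nn.Conv1d(d, d, kernel_size=k, groups=d, bias=False)  # depthwise, like fsmn_block

    chunk = torch.randn(1, d, t)          # current chunk of t frames
    cache = torch.zeros(1, d, k + t - 1)  # assumed warm cache from earlier chunks

    x = torch.cat((cache[:, :, 1:], chunk), dim=2)
    x = x[:, :, -(k + t - 1):]            # new slice: keep k + t - 1 frames
    print(conv(x).shape)                  # torch.Size([1, 4, 5]) -> t outputs per chunk

    x_old = x[:, :, -k:]                  # old slice kept only kernel_size frames
    print(conv(x_old).shape)              # torch.Size([1, 4, 1]) -> only right when t == 1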
---
funasr/modules/attention.py | 22 ++++++++++++++++++----
1 file changed, 18 insertions(+), 4 deletions(-)
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index c47d96d..31d5a87 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -347,15 +347,17 @@
mask = torch.reshape(mask, (b, -1, 1))
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
+ inputs = inputs * mask
- inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
x = self.dropout(x)
- return x * mask
+ if mask is not None:
+ x = x * mask
+ return x
def forward_qkv(self, x):
"""Transform query, key and value.
@@ -439,6 +441,18 @@
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs + fsmn_memory
+class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
+ return att_outs + fsmn_memory
+
class MultiHeadedAttentionSANMDecoder(nn.Module):
"""Multi-Head Attention layer.
@@ -493,7 +507,7 @@
# print("in fsmn, cache is None, x", x.size())
x = self.pad_fn(x)
- if not self.training and t <= 1:
+ if not self.training:
cache = x
else:
# print("in fsmn, cache is not None, x", x.size())
@@ -501,7 +515,7 @@
# if t < self.kernel_size:
# x = self.pad_fn(x)
x = torch.cat((cache[:, :, 1:], x), dim=2)
- x = x[:, :, -self.kernel_size:]
+ x = x[:, :, -(self.kernel_size+t-1):]
# print("in fsmn, cache is not None, x_cat", x.size())
cache = x
x = self.fsmn_block(x)
--
Gitblit v1.9.1