From 4ba1011b42e041ee1d71448eefd7ef2e7bd61bb6 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 31 三月 2023 15:31:26 +0800
Subject: [PATCH] export
---
funasr/modules/attention.py | 128 +++++++++++++++++++++++++++++++++++++++++-
1 files changed, 123 insertions(+), 5 deletions(-)
diff --git a/funasr/modules/attention.py b/funasr/modules/attention.py
index e3ad56a..31d5a87 100644
--- a/funasr/modules/attention.py
+++ b/funasr/modules/attention.py
@@ -347,15 +347,17 @@
mask = torch.reshape(mask, (b, -1, 1))
if mask_shfit_chunk is not None:
mask = mask * mask_shfit_chunk
+ inputs = inputs * mask
- inputs = inputs * mask
x = inputs.transpose(1, 2)
x = self.pad_fn(x)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
x += inputs
x = self.dropout(x)
- return x * mask
+ if mask is not None:
+ x = x * mask
+ return x
def forward_qkv(self, x):
"""Transform query, key and value.
@@ -439,6 +441,18 @@
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
return att_outs + fsmn_memory
+class MultiHeadedAttentionSANMwithMask(MultiHeadedAttentionSANM):
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+
+ def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ fsmn_memory = self.forward_fsmn(v, mask[0], mask_shfit_chunk)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, mask[1], mask_att_chunk_encoder)
+ return att_outs + fsmn_memory
+
class MultiHeadedAttentionSANMDecoder(nn.Module):
"""Multi-Head Attention layer.
@@ -493,7 +507,7 @@
# print("in fsmn, cache is None, x", x.size())
x = self.pad_fn(x)
- if not self.training and t <= 1:
+ if not self.training:
cache = x
else:
# print("in fsmn, cache is not None, x", x.size())
@@ -501,7 +515,7 @@
# if t < self.kernel_size:
# x = self.pad_fn(x)
x = torch.cat((cache[:, :, 1:], x), dim=2)
- x = x[:, :, -self.kernel_size:]
+ x = x[:, :, -(self.kernel_size+t-1):]
# print("in fsmn, cache is not None, x_cat", x.size())
cache = x
x = self.fsmn_block(x)
@@ -622,4 +636,108 @@
q_h, k_h, v_h = self.forward_qkv(x, memory)
q_h = q_h * self.d_k ** (-0.5)
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
- return self.forward_attention(v_h, scores, memory_mask)
\ No newline at end of file
+ return self.forward_attention(v_h, scores, memory_mask)
+
+
+class MultiHeadSelfAttention(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, in_feat, n_feat, dropout_rate):
+ """Construct an MultiHeadedAttention object."""
+ super(MultiHeadSelfAttention, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(self, x):
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+ """
+ b, t, d = x.size()
+ q_k_v = self.linear_q_k_v(x)
+ q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
+ q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time1, d_k)
+ k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+ v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q_h, k_h, v_h, v
+
+ def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ if mask_att_chunk_encoder is not None:
+ mask = mask * mask_att_chunk_encoder
+
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+
+ min_value = float(
+ numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min
+ )
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+ def forward(self, x, mask, mask_att_chunk_encoder=None):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+
+ """
+ q_h, k_h, v_h, v = self.forward_qkv(x)
+ q_h = q_h * self.d_k ** (-0.5)
+ scores = torch.matmul(q_h, k_h.transpose(-2, -1))
+ att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
+ return att_outs
--
Gitblit v1.9.1