```diff
         self.attn = None
         self.all_head_size = self.h * self.d_k

-    def forward(self, x, memory, memory_mask):
+    def forward(self, x, memory, memory_mask, ret_attn=False):
         q, k, v = self.forward_qkv(x, memory)
         scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, memory_mask)
+        return self.forward_attention(v, scores, memory_mask, ret_attn)

     def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
         new_x_shape = x.size()[:-1] + (self.h, self.d_k)
         ...

         v = self.transpose_for_scores(v)
         return q, k, v

-    def forward_attention(self, value, scores, mask):
+    def forward_attention(self, value, scores, mask, ret_attn):
         scores = scores + mask

         self.attn = torch.softmax(scores, dim=-1)
         ...
         context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
         context_layer = context_layer.view(new_context_layer_shape)
+        if ret_attn:
+            return self.linear_out(context_layer), self.attn
         return self.linear_out(context_layer)  # (batch, time1, d_model)

         ...
         scores = torch.matmul(q_h, k_h.transpose(-2, -1))
         att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
         return att_outs
```
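The change above threads an optional `ret_attn` flag through `forward` and `forward_attention`, so a caller can receive the softmax attention map alongside the projected output instead of only reading it back from `self.attn`. Below is a minimal, self-contained sketch of that pattern; the class name, the `linear_q`/`linear_kv` projection layout, and the additive-mask convention are assumptions made for illustration, not details taken from the patched code.

```python
# Illustrative sketch of the ret_attn pattern (names and sizes are assumptions).
import math
import torch
import torch.nn as nn


class TinyCrossAttention(nn.Module):
    def __init__(self, n_head: int, d_model: int):
        super().__init__()
        assert d_model % n_head == 0
        self.h = n_head
        self.d_k = d_model // n_head
        self.all_head_size = self.h * self.d_k
        self.linear_q = nn.Linear(d_model, d_model)
        self.linear_kv = nn.Linear(d_model, d_model * 2)
        self.linear_out = nn.Linear(d_model, d_model)
        self.attn = None

    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, time, d_model) -> (batch, head, time, d_k)
        new_x_shape = x.size()[:-1] + (self.h, self.d_k)
        return x.view(new_x_shape).permute(0, 2, 1, 3)

    def forward(self, x, memory, memory_mask, ret_attn=False):
        q = self.transpose_for_scores(self.linear_q(x))
        k, v = self.linear_kv(memory).chunk(2, dim=-1)
        k = self.transpose_for_scores(k)
        v = self.transpose_for_scores(v)

        # memory_mask is an additive mask: 0 for valid positions,
        # a large negative value for padded ones.
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores + memory_mask
        self.attn = torch.softmax(scores, dim=-1)

        context = torch.matmul(self.attn, v)                 # (batch, head, time1, d_k)
        context = context.permute(0, 2, 1, 3).contiguous()   # (batch, time1, head, d_k)
        context = context.view(context.size()[:-2] + (self.all_head_size,))
        out = self.linear_out(context)                       # (batch, time1, d_model)
        if ret_attn:
            return out, self.attn
        return out


# Usage: the attention map is only returned when ret_attn=True.
mha = TinyCrossAttention(n_head=4, d_model=256)
x = torch.randn(2, 7, 256)        # query-side input
memory = torch.randn(2, 30, 256)  # encoder memory
mask = torch.zeros(2, 1, 1, 30)   # no padding -> all-zero additive mask
out, attn = mha(x, memory, mask, ret_attn=True)
print(out.shape, attn.shape)      # torch.Size([2, 7, 256]) torch.Size([2, 4, 7, 30])
```

Returning the attention weights this way keeps the default call signature and return type unchanged, so existing callers are unaffected and only code that opts in with `ret_attn=True` has to unpack the extra value.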