| | |
| | | |
| | | return q_h, k_h, v_h |
| | | |
| | | def forward_attention(self, value, scores, mask): |
| | | def forward_attention(self, value, scores, mask, ret_attn=False): |
| | | """Compute attention context vector. |
| | | |
| | | Args: |
| | |
| | | ) # (batch, head, time1, time2) |
| | | else: |
| | | self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) |
| | | |
| | | p_attn = self.dropout(self.attn) |
| | | x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) |
| | | x = ( |
| | | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) |
| | | ) # (batch, time1, d_model) |
| | | |
| | | if ret_attn: |
| | | return self.linear_out(x), self.attn # (batch, time1, d_model) |
| | | return self.linear_out(x) # (batch, time1, d_model) |
| | | |
| | | def forward(self, x, memory, memory_mask): |
| | | def forward(self, x, memory, memory_mask, ret_attn=False): |
| | | """Compute scaled dot product attention. |
| | | |
| | | Args: |
| | |
| | | q_h, k_h, v_h = self.forward_qkv(x, memory) |
| | | q_h = q_h * self.d_k ** (-0.5) |
| | | scores = torch.matmul(q_h, k_h.transpose(-2, -1)) |
| | | return self.forward_attention(v_h, scores, memory_mask) |
| | | return self.forward_attention(v_h, scores, memory_mask, ret_attn=ret_attn) |
| | | |
| | | def forward_chunk(self, x, memory, cache=None, chunk_size=None, look_back=0): |
| | | """Compute scaled dot product attention. |