| | |
def forward_attention(self, value, scores, mask):
    """Apply masked softmax attention weights to the value tensor.

    Args:
        value: value tensor — assumed (batch, head, time2, d_k); TODO confirm.
        scores: raw attention scores — assumed (batch, head, time1, time2).
        mask: additive mask, summed onto ``scores`` before the softmax
            (presumably 0 for keep, large negative for masked positions —
            verify against the caller).

    Returns:
        Context tensor with heads merged, shape
        (batch, time1, self.all_head_size).

    Side effects:
        Caches the softmax attention weights in ``self.attn``.
    """
    scores = scores + mask

    # NOTE(review): the original computed the softmax and matmul twice
    # (once into self.attn, once into a local) — merge-conflict residue;
    # a single pass is equivalent.
    self.attn = torch.softmax(scores, dim=-1)
    context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

    # (batch, head, time1, d_k) -> (batch, time1, head, d_k), then merge heads.
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    # The original computed this shape but never used it and fell off the end;
    # complete the standard head-merge and return the context.
    return context_layer.view(new_context_layer_shape)
| | |
def forward_attention(self, value, scores, mask):
    """Apply masked softmax attention weights to the value tensor.

    Args:
        value: value tensor — assumed (batch, head, time2, d_k); TODO confirm.
        scores: raw attention scores — assumed (batch, head, time1, time2).
        mask: additive mask, summed onto ``scores`` before the softmax
            (presumably 0 for keep, large negative for masked positions —
            verify against the caller).

    Returns:
        Context tensor with heads merged, shape
        (batch, time1, self.all_head_size).

    Side effects:
        Caches the softmax attention weights in ``self.attn``.
    """
    scores = scores + mask

    # NOTE(review): the original computed the softmax and matmul twice
    # (once into self.attn, once into a local) — merge-conflict residue;
    # a single pass is equivalent.
    self.attn = torch.softmax(scores, dim=-1)
    context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

    # (batch, head, time1, d_k) -> (batch, time1, head, d_k), then merge heads.
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    # The original computed this shape but never used it and fell off the end;
    # complete the standard head-merge and return the context.
    return context_layer.view(new_context_layer_shape)
| | |
def forward_attention(self, value, scores, mask):
    """Apply masked softmax attention weights to the value tensor.

    Args:
        value: value tensor — assumed (batch, head, time2, d_k); TODO confirm.
        scores: raw attention scores — assumed (batch, head, time1, time2).
        mask: additive mask, summed onto ``scores`` before the softmax
            (presumably 0 for keep, large negative for masked positions —
            verify against the caller).

    Returns:
        Context tensor with heads merged, shape
        (batch, time1, self.all_head_size).

    Side effects:
        Caches the softmax attention weights in ``self.attn``.
    """
    scores = scores + mask

    # NOTE(review): the original computed the softmax and matmul twice
    # (once into self.attn, once into a local) — merge-conflict residue;
    # a single pass is equivalent.
    self.attn = torch.softmax(scores, dim=-1)
    context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

    # (batch, head, time1, d_k) -> (batch, time1, head, d_k), then merge heads.
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    # The original computed this shape but never used it and fell off the end;
    # complete the standard head-merge and return the context.
    return context_layer.view(new_context_layer_shape)
| | |
def forward_attention(self, value, scores, mask):
    """Apply masked softmax attention weights to the value tensor.

    Args:
        value: value tensor — assumed (batch, head, time2, d_k); TODO confirm.
        scores: raw attention scores — assumed (batch, head, time1, time2).
        mask: additive mask, summed onto ``scores`` before the softmax
            (presumably 0 for keep, large negative for masked positions —
            verify against the caller).

    Returns:
        Context tensor with heads merged, shape
        (batch, time1, self.all_head_size).

    Side effects:
        Caches the softmax attention weights in ``self.attn``.
    """
    scores = scores + mask

    # NOTE(review): the original computed the softmax and matmul twice
    # (once into self.attn, once into a local) — merge-conflict residue;
    # a single pass is equivalent.
    self.attn = torch.softmax(scores, dim=-1)
    context_layer = torch.matmul(self.attn, value)  # (batch, head, time1, d_k)

    # (batch, head, time1, d_k) -> (batch, time1, head, d_k), then merge heads.
    context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
    new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
    # The original computed this shape but never used it and fell off the end;
    # complete the standard head-merge and return the context.
    return context_layer.view(new_context_layer_shape)