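# Additive self-attention (Fastformer-style): pool the queries into a single
# global query per head, use it to modulate the keys, pool those into a global
# key, and weight the values with it; cost is linear rather than quadratic in
# sequence length.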
# (batch, n_heads, time)
query_for_score = (
    self.query_att(mixed_query_layer).transpose(1, 2)
    / self.attention_head_size**0.5
)
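# Fill padded positions with the dtype minimum so softmax assigns them
# near-zero weight, then zero those weights exactly after the softmax.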
if mask is not None:
    min_value = float(
        numpy.finfo(
            torch.tensor(0, dtype=query_for_score.dtype).numpy().dtype
        ).min
    )
    query_for_score = query_for_score.masked_fill(mask, min_value)
    query_weight = torch.softmax(query_for_score, dim=-1).masked_fill(mask, 0.0)
else:
    query_weight = torch.softmax(query_for_score, dim=-1)

query_weight = query_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
query_layer = self.transpose_for_scores(
    mixed_query_layer
)  # (batch, n_heads, time, attn_dim)
# Pool the queries over time into one global query vector per head.
pooled_query = (
    torch.matmul(query_weight, query_layer)
    .transpose(1, 2)
    .reshape(-1, 1, self.num_attention_heads * self.attention_head_size)
)  # (batch, 1, size)
pooled_query = self.dropout(pooled_query)
pooled_query_repeat = pooled_query.repeat(1, seq_len, 1)  # (batch, time, size)
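
# Modulate every key by the global query (element-wise product).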
mixed_query_key_layer = (
    mixed_key_layer * pooled_query_repeat
)  # (batch, time, size)
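
# Second additive attention: score the modulated keys and pool them into a
# single global key per head.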
# (batch, n_heads, time)
query_key_score = (
    self.key_att(mixed_query_key_layer) / self.attention_head_size**0.5
).transpose(1, 2)
if mask is not None:
    min_value = float(
        numpy.finfo(
            torch.tensor(0, dtype=query_key_score.dtype).numpy().dtype
        ).min
    )
    query_key_score = query_key_score.masked_fill(mask, min_value)
    query_key_weight = torch.softmax(query_key_score, dim=-1).masked_fill(
        mask, 0.0
    )
else:
    query_key_weight = torch.softmax(query_key_score, dim=-1)

query_key_weight = query_key_weight.unsqueeze(2)  # (batch, n_heads, 1, time)
key_layer = self.transpose_for_scores(
    mixed_query_key_layer
)  # (batch, n_heads, time, attn_dim)
pooled_key = torch.matmul(
    query_key_weight, key_layer
)  # (batch, n_heads, 1, attn_dim)
pooled_key = self.dropout(pooled_key)
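
# Weight each position's value by the pooled global key (broadcast over time).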
# NOTE: value = query, due to param sharing
weighted_value = (pooled_key * query_layer).transpose(
    1, 2
)  # (batch, time, n_heads, attn_dim)
weighted_value = weighted_value.reshape(
    weighted_value.shape[:-2]
    + (self.num_attention_heads * self.attention_head_size,)
)  # (batch, time, size)
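# Final linear transform with dropout, plus a residual from the query stream.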
weighted_value = (
    self.dropout(self.transform(weighted_value)) + mixed_query_layer
)

return weighted_value