        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)

        q = self.query(x)
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
        return self.out(wv), qk

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            if not is_pad_mask:
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                min_value = float(
                    np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
                )
                qk = qk.masked_fill(mask, min_value)

        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        if mask is not None and is_pad_mask:
            w = w.masked_fill(mask, 0.0)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
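
The padding branch above expects the mask to be a batch-level padding mask rather than the usual additive causal mask. Judging from the `mask.unsqueeze(1).eq(0)` handling and the `(batch, 1, *, time2)` comment, a mask of shape `(batch, 1, time2)` with nonzero entries for valid frames and zeros for padding broadcasts correctly against the `(batch, n_head, n_ctx, time2)` attention scores. A minimal sketch of how such a mask could be built from per-utterance lengths follows; the helper name `make_pad_mask_example` and the usage lines are illustrative, not part of the original code:

```python
import torch

def make_pad_mask_example(lengths: torch.Tensor, max_len: int) -> torch.Tensor:
    """Return a (batch, 1, max_len) mask: True for valid frames, False for padding."""
    ar = torch.arange(max_len, device=lengths.device)
    valid = ar[None, :] < lengths[:, None]   # (batch, max_len)
    return valid.unsqueeze(1)                # (batch, 1, max_len)

# Illustrative call (mha stands for a MultiHeadAttention instance):
# lengths = torch.tensor([80, 64, 100])
# mask = make_pad_mask_example(lengths, max_len=100)
# wv, qk = mha.qkv_attention(q, k, v, mask, is_pad_mask=True)
```
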
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x
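
As modified, the block forwards `is_pad_mask` to the self-attention call and `is_pad_memory_mask` to the cross-attention call, while the default values keep the original behaviour. A hedged usage sketch, assuming a ResidualAttentionBlock-style module named `block` and reusing `make_pad_mask_example` from the sketch above (all variable names here are illustrative):

```python
import torch

x = torch.randn(3, 100, 384)                  # (batch, n_ctx, n_state), padded batch
lengths = torch.tensor([80, 64, 100])         # valid frames per utterance
pad_mask = make_pad_mask_example(lengths, max_len=x.shape[1])   # (batch, 1, n_ctx)

# Encoder-style self-attention that ignores padded frames:
# x = block(x, mask=pad_mask, is_pad_mask=True)

# Decoder-style call keeps the usual additive causal mask; is_pad_mask defaults
# to False, so existing behaviour is unchanged:
# x = block(x, xa, mask=causal_mask, kv_cache=kv_cache)
```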