    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)

        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
        return self.out(wv), qk
| | | |
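In the cached branch, cross-attention keys and values are projected once per audio segment and reused on every decoding step. A minimal sketch of that pattern, keyed by module object as above (illustrative only; Whisper populates the cache through forward hooks rather than a helper like this):

```python
import torch
from torch import nn

kv_cache: dict = {}

def cached_projection(proj: nn.Linear, xa: torch.Tensor) -> torch.Tensor:
    # Project the encoder output once, then reuse it on later decode steps.
    if proj not in kv_cache:
        kv_cache[proj] = proj(xa)
    return kv_cache[proj]
```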

    def qkv_attention(
        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            if not is_pad_mask:
                # causal mask: additive, upper-triangular -inf
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
                # padding mask: boolean, True where a frame is padding
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                min_value = float(
                    np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min
                )
                qk = qk.masked_fill(mask, min_value)

        qk = qk.float()

        w = F.softmax(qk, dim=-1).to(q.dtype)
        if mask is not None and is_pad_mask:
            # zero out any residual weight on padded positions
            w = w.masked_fill(mask, 0.0)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
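To make the padding branch concrete, here is a self-contained sketch (shapes and values are illustrative) of how a `(batch, 1, time2)` pad mask, with zeros marking padded frames, removes those frames from the attention distribution:

```python
import numpy as np
import torch
import torch.nn.functional as F

qk = torch.randn(2, 8, 4, 4)                # (batch, head, time1, time2)
pad_mask = torch.tensor([[[1, 1, 1, 1]],
                         [[1, 1, 0, 0]]])   # (batch, 1, time2); 0 = padding

mask = pad_mask.unsqueeze(1).eq(0)          # (batch, 1, 1, time2), True where padded
min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)

w = F.softmax(qk.masked_fill(mask, min_value), dim=-1)
w = w.masked_fill(mask, 0.0)                # padded keys end up with exactly zero weight
assert w[1, ..., 2:].sum() == 0
```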

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x
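A hedged usage sketch (the block instance and shapes are illustrative): with `is_pad_mask=True`, the `mask` argument is treated as a per-batch padding mask instead of the usual additive causal mask:

```python
import torch

# Assumes `block` is a ResidualAttentionBlock(n_state=512, n_head=8).
x = torch.randn(2, 4, 512)                  # (batch, time, n_state)
pad_mask = torch.tensor([[[1, 1, 1, 1]],
                         [[1, 1, 0, 0]]])   # (batch, 1, time); 0 = padding

y = block(x, mask=pad_mask, is_pad_mask=True)
```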

    def __init__(
        self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int
    ):
        super().__init__()
        # conv1 now also strides by 2, doubling the encoder's downsampling to 4x
        # self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

    def forward(self, x: Tensor):
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)

        # the fixed-length assertion is dropped so shorter inputs are accepted;
        # the positional embedding is sliced to the actual number of frames
        # assert x.shape[1:] == self.positional_embedding.shape, "incorrect audio shape"
        # x = (x + self.positional_embedding).to(x.dtype)
        x = (x + self.positional_embedding[: x.size(1), :]).to(x.dtype)

        for block in self.blocks:
            x = block(x)
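Because the positional embedding is now sliced to `x.size(1)`, the encoder accepts inputs shorter than the fixed 30-second window. A sketch, assuming `model` is a Whisper model patched as above:

```python
import torch

mel = torch.randn(1, 80, 1000)   # a 10-second clip: 1000 mel frames instead of 3000
features = model.encoder(mel)
print(features.shape)            # (1, 250, n_state) given the two stride-2 convs
```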

        all_heads = torch.zeros(
            self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool
        )
        all_heads[self.dims.n_text_layer // 2 :] = True
        # the sparse registration is disabled; a dense copy is registered on the
        # loaded model instead, as in the two commented lines below
        # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
        # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
        # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)
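For instance, converting the buffer after loading a stock checkpoint (a sketch assuming the standard `whisper` package API):

```python
import whisper

model = whisper.load_model("small")

# Replace the sparse alignment-heads buffer with a dense copy, mirroring the
# commented lines above; re-registering under the same name overwrites it.
dense = model.get_buffer("alignment_heads").to_dense()
model.register_buffer("alignment_heads", dense, persistent=False)
```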

    def set_alignment_heads(self, dump: bytes):
        array = np.frombuffer(
            gzip.decompress(base64.b85decode(dump)), dtype=bool
        ).copy()

    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function