| | |
| | | wkv_kernel_encoder = None |
| | | wkv_kernel_decoder = None |
| | | |
| | | |
| | | class WKVLinearAttentionEncoder(torch.autograd.Function): |
| | | """WKVLinearAttention function definition.""" |
| | | |
| | |
| | | ) |
| | | |
| | | assert batch * dim % min(dim, 32) == 0, ( |
| | | f"batch size ({batch}) by dimension ({dim}) should be a multiple of " |
| | | f"{min(dim, 32)}" |
| | | f"batch size ({batch}) by dimension ({dim}) should be a multiple of " f"{min(dim, 32)}" |
| | | ) |
| | | |
| | | ctx.input_dtype = key.dtype |
| | |
| | | grad_value, |
| | | ) |
| | | |
| | | |
| | | class WKVLinearAttentionDecoder(torch.autograd.Function): |
| | | """WKVLinearAttention function definition.""" |
| | | |
| | |
| | | ) |
| | | |
| | | assert batch * dim % min(dim, 32) == 0, ( |
| | | f"batch size ({batch}) by dimension ({dim}) should be a multiple of " |
| | | f"{min(dim, 32)}" |
| | | f"batch size ({batch}) by dimension ({dim}) should be a multiple of " f"{min(dim, 32)}" |
| | | ) |
| | | |
| | | ctx.input_dtype = key.dtype |
| | |
| | | grad_value, |
| | | ) |
| | | |
| | | |
| | | def load_encoder_wkv_kernel(context_size: int) -> None: |
| | | """Load WKV CUDA kernel. |
| | | |
| | |
| | | ) |
| | | wkv_kernel_encoder.context_size = context_size |
| | | |
| | | |
| | | def load_decoder_wkv_kernel(context_size: int) -> None: |
| | | """Load WKV CUDA kernel. |
| | | |
| | |
| | | extra_cuda_cflags=kernel_cflags, |
| | | ) |
| | | wkv_kernel_decoder.context_size = context_size |
| | | |
| | | |
| | | class SelfAttention(torch.nn.Module): |
| | | """SelfAttention module definition. |
| | |
| | | |
| | | with torch.no_grad(): |
| | | self.time_decay.data = decay_speed |
| | | self.time_first.data = torch.ones_like( |
| | | self.time_first * math.log(0.3) + zigzag |
| | | ) |
| | | self.time_first.data = torch.ones_like(self.time_first * math.log(0.3) + zigzag) |
| | | |
| | | self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0) |
| | | self.time_mix_value.data = ( |
| | | torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1 |
| | | ) |
| | | self.time_mix_receptance.data = torch.pow( |
| | | time_weight, 0.5 * ratio_1_to_almost0 |
| | | ) |
| | | self.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0) |
| | | |
| | | @torch.no_grad() |
| | | def wkv_linear_attention( |
| | |
| | | num_blocks: int, |
| | | ) -> None: |
| | | """Construct a SelfAttention object.""" |
| | | super().__init__( |
| | | size, |
| | | attention_size, |
| | | block_id, |
| | | dropout_rate, |
| | | num_blocks |
| | | ) |
| | | super().__init__(size, attention_size, block_id, dropout_rate, num_blocks) |
| | | # load_decoder_wkv_kernel(context_size) |
| | | |
| | | def forward( |
| | |
| | | x: SelfAttention output sequences. (B, U, size) |
| | | |
| | | """ |
| | | shifted_x = ( |
| | | self.time_shift(x) if state is None else state[1][..., self.block_id] |
| | | ) |
| | | shifted_x = self.time_shift(x) if state is None else state[1][..., self.block_id] |
| | | |
| | | key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key) |
| | | value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value) |
| | | receptance = x * self.time_mix_receptance + shifted_x * ( |
| | | 1 - self.time_mix_receptance |
| | | ) |
| | | receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance) |
| | | |
| | | key = self.proj_key(key) |
| | | value = self.proj_value(value) |
| | |
| | | |
| | | return x, state |
| | | |
| | | |
| | | class EncoderSelfAttention(SelfAttention): |
| | | """SelfAttention module definition. |
| | | |
| | |
| | | num_blocks: int, |
| | | ) -> None: |
| | | """Construct a SelfAttention object.""" |
| | | super().__init__( |
| | | size, |
| | | attention_size, |
| | | block_id, |
| | | dropout_rate, |
| | | num_blocks |
| | | ) |
| | | super().__init__(size, attention_size, block_id, dropout_rate, num_blocks) |
| | | # load_encoder_wkv_kernel(context_size) |
| | | |
| | | def forward( |
| | |
| | | x: SelfAttention output sequences. (B, U, size) |
| | | |
| | | """ |
| | | shifted_x = ( |
| | | self.time_shift(x) if state is None else state[1][..., self.block_id] |
| | | ) |
| | | shifted_x = self.time_shift(x) if state is None else state[1][..., self.block_id] |
| | | |
| | | key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key) |
| | | value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value) |
| | | receptance = x * self.time_mix_receptance + shifted_x * ( |
| | | 1 - self.time_mix_receptance |
| | | ) |
| | | receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance) |
| | | |
| | | key = self.proj_key(key) |
| | | value = self.proj_value(value) |
| | |
| | | x = self.proj_output(receptance * wkv) |
| | | |
| | | return x, state |
| | | |