游雁
2024-06-24 1596f6f414f6f41da66506debb1dff19fffeb3ec
funasr/models/rwkv_bat/rwkv_attention.py
@@ -13,6 +13,7 @@
wkv_kernel_encoder = None
wkv_kernel_decoder = None
class WKVLinearAttentionEncoder(torch.autograd.Function):
    """WKVLinearAttention function definition."""
@@ -44,8 +45,7 @@
        )
        assert batch * dim % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of " f"{min(dim, 32)}"
        )
        ctx.input_dtype = key.dtype
@@ -124,6 +124,7 @@
            grad_value,
        )
class WKVLinearAttentionDecoder(torch.autograd.Function):
    """WKVLinearAttention function definition."""
@@ -155,8 +156,7 @@
        )
        assert batch * dim % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of " f"{min(dim, 32)}"
        )
        ctx.input_dtype = key.dtype
@@ -235,6 +235,7 @@
            grad_value,
        )
def load_encoder_wkv_kernel(context_size: int) -> None:
    """Load WKV CUDA kernel.
@@ -280,6 +281,7 @@
    )
    wkv_kernel_encoder.context_size = context_size
def load_decoder_wkv_kernel(context_size: int) -> None:
    """Load WKV CUDA kernel.
@@ -324,6 +326,7 @@
        extra_cuda_cflags=kernel_cflags,
    )
    wkv_kernel_decoder.context_size = context_size
class SelfAttention(torch.nn.Module):
    """SelfAttention module definition.
@@ -406,17 +409,13 @@
        with torch.no_grad():
            self.time_decay.data = decay_speed
            self.time_first.data = torch.ones_like(
                self.time_first * math.log(0.3) + zigzag
            )
            self.time_first.data = torch.ones_like(self.time_first * math.log(0.3) + zigzag)
            self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
            self.time_mix_value.data = (
                torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
            )
            self.time_mix_receptance.data = torch.pow(
                time_weight, 0.5 * ratio_1_to_almost0
            )
            self.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
    @torch.no_grad()
    def wkv_linear_attention(
@@ -485,13 +484,7 @@
        num_blocks: int,
    ) -> None:
        """Construct a SelfAttention object."""
        super().__init__(
            size,
            attention_size,
            block_id,
            dropout_rate,
            num_blocks
        )
        super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)
        # load_decoder_wkv_kernel(context_size)
    def forward(
@@ -509,15 +502,11 @@
            x: SelfAttention output sequences. (B, U, size)
        """
        shifted_x = (
            self.time_shift(x) if state is None else state[1][..., self.block_id]
        )
        shifted_x = self.time_shift(x) if state is None else state[1][..., self.block_id]
        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )
        receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance)
        key = self.proj_key(key)
        value = self.proj_value(value)
@@ -545,6 +534,7 @@
        return x, state
class EncoderSelfAttention(SelfAttention):
    """SelfAttention module definition.
@@ -567,13 +557,7 @@
        num_blocks: int,
    ) -> None:
        """Construct a SelfAttention object."""
        super().__init__(
            size,
            attention_size,
            block_id,
            dropout_rate,
            num_blocks
        )
        super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)
        # load_encoder_wkv_kernel(context_size)
    def forward(
@@ -591,15 +575,11 @@
            x: SelfAttention output sequences. (B, U, size)
        """
        shifted_x = (
            self.time_shift(x) if state is None else state[1][..., self.block_id]
        )
        shifted_x = self.time_shift(x) if state is None else state[1][..., self.block_id]
        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )
        receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance)
        key = self.proj_key(key)
        value = self.proj_value(value)
@@ -626,4 +606,3 @@
        x = self.proj_output(receptance * wkv)
        return x, state