From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' on line 514 to avoid a naming conflict with line 365
---
funasr/models/rwkv_bat/rwkv_attention.py | 53 ++++++++++++++++-------------------------------------
 1 file changed, 16 insertions(+), 37 deletions(-)
diff --git a/funasr/models/rwkv_bat/rwkv_attention.py b/funasr/models/rwkv_bat/rwkv_attention.py
index c085874..59bf0ff 100644
--- a/funasr/models/rwkv_bat/rwkv_attention.py
+++ b/funasr/models/rwkv_bat/rwkv_attention.py
@@ -13,6 +13,7 @@
wkv_kernel_encoder = None
wkv_kernel_decoder = None
+
class WKVLinearAttentionEncoder(torch.autograd.Function):
"""WKVLinearAttention function definition."""
@@ -44,8 +45,7 @@
)
assert batch * dim % min(dim, 32) == 0, (
- f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
- f"{min(dim, 32)}"
+ f"batch size ({batch}) by dimension ({dim}) should be a multiple of " f"{min(dim, 32)}"
)
ctx.input_dtype = key.dtype
@@ -124,6 +124,7 @@
grad_value,
)
+
class WKVLinearAttentionDecoder(torch.autograd.Function):
"""WKVLinearAttention function definition."""
@@ -155,8 +156,7 @@
)
assert batch * dim % min(dim, 32) == 0, (
- f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
- f"{min(dim, 32)}"
+ f"batch size ({batch}) by dimension ({dim}) should be a multiple of " f"{min(dim, 32)}"
)
ctx.input_dtype = key.dtype
@@ -235,6 +235,7 @@
grad_value,
)
+
def load_encoder_wkv_kernel(context_size: int) -> None:
"""Load WKV CUDA kernel.
@@ -280,6 +281,7 @@
)
wkv_kernel_encoder.context_size = context_size
+
def load_decoder_wkv_kernel(context_size: int) -> None:
"""Load WKV CUDA kernel.
@@ -324,6 +326,7 @@
extra_cuda_cflags=kernel_cflags,
)
wkv_kernel_decoder.context_size = context_size
+
class SelfAttention(torch.nn.Module):
"""SelfAttention module definition.
@@ -406,17 +409,13 @@
with torch.no_grad():
self.time_decay.data = decay_speed
- self.time_first.data = torch.ones_like(
- self.time_first * math.log(0.3) + zigzag
- )
+ self.time_first.data = torch.ones_like(self.time_first * math.log(0.3) + zigzag)
self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
self.time_mix_value.data = (
torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
)
- self.time_mix_receptance.data = torch.pow(
- time_weight, 0.5 * ratio_1_to_almost0
- )
+ self.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
@torch.no_grad()
def wkv_linear_attention(
@@ -485,13 +484,7 @@
num_blocks: int,
) -> None:
"""Construct a SelfAttention object."""
- super().__init__(
- size,
- attention_size,
- block_id,
- dropout_rate,
- num_blocks
- )
+ super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)
# load_decoder_wkv_kernel(context_size)
def forward(
@@ -509,15 +502,11 @@
x: SelfAttention output sequences. (B, U, size)
"""
- shifted_x = (
- self.time_shift(x) if state is None else state[1][..., self.block_id]
- )
+ shifted_x = self.time_shift(x) if state is None else state[1][..., self.block_id]
key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
- receptance = x * self.time_mix_receptance + shifted_x * (
- 1 - self.time_mix_receptance
- )
+ receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance)
key = self.proj_key(key)
value = self.proj_value(value)
@@ -545,6 +534,7 @@
return x, state
+
class EncoderSelfAttention(SelfAttention):
"""SelfAttention module definition.
@@ -567,13 +557,7 @@
num_blocks: int,
) -> None:
"""Construct a SelfAttention object."""
- super().__init__(
- size,
- attention_size,
- block_id,
- dropout_rate,
- num_blocks
- )
+ super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)
# load_encoder_wkv_kernel(context_size)
def forward(
@@ -591,15 +575,11 @@
x: SelfAttention output sequences. (B, U, size)
"""
- shifted_x = (
- self.time_shift(x) if state is None else state[1][..., self.block_id]
- )
+ shifted_x = self.time_shift(x) if state is None else state[1][..., self.block_id]
key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
- receptance = x * self.time_mix_receptance + shifted_x * (
- 1 - self.time_mix_receptance
- )
+ receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance)
key = self.proj_key(key)
value = self.proj_value(value)
@@ -626,4 +606,3 @@
x = self.proj_output(receptance * wkv)
return x, state
-
--
Gitblit v1.9.1