From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/models/rwkv_bat/rwkv_attention.py |   53 ++++++++++++++++-------------------------------------
 1 file changed, 16 insertions(+), 37 deletions(-)

diff --git a/funasr/models/rwkv_bat/rwkv_attention.py b/funasr/models/rwkv_bat/rwkv_attention.py
index c085874..59bf0ff 100644
--- a/funasr/models/rwkv_bat/rwkv_attention.py
+++ b/funasr/models/rwkv_bat/rwkv_attention.py
@@ -13,6 +13,7 @@
 wkv_kernel_encoder = None
 wkv_kernel_decoder = None
 
+
 class WKVLinearAttentionEncoder(torch.autograd.Function):
     """WKVLinearAttention function definition."""
 
@@ -44,8 +45,7 @@
         )
 
         assert batch * dim % min(dim, 32) == 0, (
-            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
-            f"{min(dim, 32)}"
+            f"batch size ({batch}) by dimension ({dim}) should be a multiple of " f"{min(dim, 32)}"
         )
 
         ctx.input_dtype = key.dtype
@@ -124,6 +124,7 @@
             grad_value,
         )
 
+
 class WKVLinearAttentionDecoder(torch.autograd.Function):
     """WKVLinearAttention function definition."""
 
@@ -155,8 +156,7 @@
         )
 
         assert batch * dim % min(dim, 32) == 0, (
-            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
-            f"{min(dim, 32)}"
+            f"batch size ({batch}) by dimension ({dim}) should be a multiple of " f"{min(dim, 32)}"
         )
 
         ctx.input_dtype = key.dtype
@@ -235,6 +235,7 @@
             grad_value,
         )
 
+
 def load_encoder_wkv_kernel(context_size: int) -> None:
     """Load WKV CUDA kernel.
 
@@ -280,6 +281,7 @@
     )
     wkv_kernel_encoder.context_size = context_size
 
+
 def load_decoder_wkv_kernel(context_size: int) -> None:
     """Load WKV CUDA kernel.
 
@@ -324,6 +326,7 @@
         extra_cuda_cflags=kernel_cflags,
     )
     wkv_kernel_decoder.context_size = context_size
+
 
 class SelfAttention(torch.nn.Module):
     """SelfAttention module definition.
@@ -406,17 +409,13 @@
 
         with torch.no_grad():
             self.time_decay.data = decay_speed
-            self.time_first.data = torch.ones_like(
-                self.time_first * math.log(0.3) + zigzag
-            )
+            self.time_first.data = torch.ones_like(self.time_first * math.log(0.3) + zigzag)
 
             self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
             self.time_mix_value.data = (
                 torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
             )
-            self.time_mix_receptance.data = torch.pow(
-                time_weight, 0.5 * ratio_1_to_almost0
-            )
+            self.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
 
     @torch.no_grad()
     def wkv_linear_attention(
@@ -485,13 +484,7 @@
         num_blocks: int,
     ) -> None:
         """Construct a SelfAttention object."""
-        super().__init__(
-            size,
-            attention_size,
-            block_id,
-            dropout_rate,
-            num_blocks
-        )
+        super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)
         # load_decoder_wkv_kernel(context_size)
 
     def forward(
@@ -509,15 +502,11 @@
             x: SelfAttention output sequences. (B, U, size)
 
         """
-        shifted_x = (
-            self.time_shift(x) if state is None else state[1][..., self.block_id]
-        )
+        shifted_x = self.time_shift(x) if state is None else state[1][..., self.block_id]
 
         key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
         value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
-        receptance = x * self.time_mix_receptance + shifted_x * (
-            1 - self.time_mix_receptance
-        )
+        receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance)
 
         key = self.proj_key(key)
         value = self.proj_value(value)
@@ -545,6 +534,7 @@
 
         return x, state
 
+
 class EncoderSelfAttention(SelfAttention):
     """SelfAttention module definition.
 
@@ -567,13 +557,7 @@
         num_blocks: int,
     ) -> None:
         """Construct a SelfAttention object."""
-        super().__init__(
-            size,
-            attention_size,
-            block_id,
-            dropout_rate,
-            num_blocks
-        )
+        super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)
         # load_encoder_wkv_kernel(context_size)
 
     def forward(
@@ -591,15 +575,11 @@
             x: SelfAttention output sequences. (B, U, size)
 
         """
-        shifted_x = (
-            self.time_shift(x) if state is None else state[1][..., self.block_id]
-        )
+        shifted_x = self.time_shift(x) if state is None else state[1][..., self.block_id]
 
         key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
         value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
-        receptance = x * self.time_mix_receptance + shifted_x * (
-            1 - self.time_mix_receptance
-        )
+        receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance)
 
         key = self.proj_key(key)
         value = self.proj_value(value)
@@ -626,4 +606,3 @@
         x = self.proj_output(receptance * wkv)
 
         return x, state
-

--
Gitblit v1.9.1