From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' on line 514 to avoid a naming conflict with line 365
---
funasr/models/rwkv_bat/rwkv_attention.py | 68 +++++++++++-----------------------
 1 file changed, 22 insertions(+), 46 deletions(-)
diff --git a/funasr/models/rwkv_bat/rwkv_attention.py b/funasr/models/rwkv_bat/rwkv_attention.py
index 5384fb9..59bf0ff 100644
--- a/funasr/models/rwkv_bat/rwkv_attention.py
+++ b/funasr/models/rwkv_bat/rwkv_attention.py
@@ -1,20 +1,18 @@
-"""Attention (time mixing) modules for RWKV block.
-
-Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py.
-
-Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
-
-""" # noqa
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
import math
-from importlib.util import find_spec
+import torch
from pathlib import Path
+from importlib.util import find_spec
from typing import List, Optional, Tuple, Union
-import torch
wkv_kernel_encoder = None
wkv_kernel_decoder = None
+
class WKVLinearAttentionEncoder(torch.autograd.Function):
"""WKVLinearAttention function definition."""
@@ -47,8 +45,7 @@
)
assert batch * dim % min(dim, 32) == 0, (
- f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
- f"{min(dim, 32)}"
+ f"batch size ({batch}) by dimension ({dim}) should be a multiple of " f"{min(dim, 32)}"
)
ctx.input_dtype = key.dtype
@@ -127,6 +124,7 @@
grad_value,
)
+
class WKVLinearAttentionDecoder(torch.autograd.Function):
"""WKVLinearAttention function definition."""
@@ -158,8 +156,7 @@
)
assert batch * dim % min(dim, 32) == 0, (
- f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
- f"{min(dim, 32)}"
+ f"batch size ({batch}) by dimension ({dim}) should be a multiple of " f"{min(dim, 32)}"
)
ctx.input_dtype = key.dtype
@@ -238,6 +235,7 @@
grad_value,
)
+
def load_encoder_wkv_kernel(context_size: int) -> None:
"""Load WKV CUDA kernel.
@@ -283,6 +281,7 @@
)
wkv_kernel_encoder.context_size = context_size
+
def load_decoder_wkv_kernel(context_size: int) -> None:
"""Load WKV CUDA kernel.
@@ -327,6 +326,7 @@
extra_cuda_cflags=kernel_cflags,
)
wkv_kernel_decoder.context_size = context_size
+
class SelfAttention(torch.nn.Module):
"""SelfAttention module definition.
@@ -409,17 +409,13 @@
with torch.no_grad():
self.time_decay.data = decay_speed
- self.time_first.data = torch.ones_like(
- self.time_first * math.log(0.3) + zigzag
- )
+ self.time_first.data = torch.ones_like(self.time_first * math.log(0.3) + zigzag)
self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
self.time_mix_value.data = (
torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
)
- self.time_mix_receptance.data = torch.pow(
- time_weight, 0.5 * ratio_1_to_almost0
- )
+ self.time_mix_receptance.data = torch.pow(time_weight, 0.5 * ratio_1_to_almost0)
@torch.no_grad()
def wkv_linear_attention(
@@ -488,13 +484,7 @@
num_blocks: int,
) -> None:
"""Construct a SelfAttention object."""
- super().__init__(
- size,
- attention_size,
- block_id,
- dropout_rate,
- num_blocks
- )
+ super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)
# load_decoder_wkv_kernel(context_size)
def forward(
@@ -512,15 +502,11 @@
x: SelfAttention output sequences. (B, U, size)
"""
- shifted_x = (
- self.time_shift(x) if state is None else state[1][..., self.block_id]
- )
+ shifted_x = self.time_shift(x) if state is None else state[1][..., self.block_id]
key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
- receptance = x * self.time_mix_receptance + shifted_x * (
- 1 - self.time_mix_receptance
- )
+ receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance)
key = self.proj_key(key)
value = self.proj_value(value)
@@ -548,6 +534,7 @@
return x, state
+
class EncoderSelfAttention(SelfAttention):
"""SelfAttention module definition.
@@ -570,13 +557,7 @@
num_blocks: int,
) -> None:
"""Construct a SelfAttention object."""
- super().__init__(
- size,
- attention_size,
- block_id,
- dropout_rate,
- num_blocks
- )
+ super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)
# load_encoder_wkv_kernel(context_size)
def forward(
@@ -594,15 +575,11 @@
x: SelfAttention output sequences. (B, U, size)
"""
- shifted_x = (
- self.time_shift(x) if state is None else state[1][..., self.block_id]
- )
+ shifted_x = self.time_shift(x) if state is None else state[1][..., self.block_id]
key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
- receptance = x * self.time_mix_receptance + shifted_x * (
- 1 - self.time_mix_receptance
- )
+ receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance)
key = self.proj_key(key)
value = self.proj_value(value)
@@ -629,4 +606,3 @@
x = self.proj_output(receptance * wkv)
return x, state
-
--
Gitblit v1.9.1