From ab653d3871f72f7f6cd1ac3126b3df722f4c7943 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: Fri, 20 Oct 2023 15:33:09 +0800
Subject: [PATCH] add rwkv encoder
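
Add an RWKV encoder (linear-attention "time mixing" and "channel mixing"
blocks, per https://arxiv.org/abs/2305.13048) as a selectable encoder for
Transducer-style ASR models. Training requires a CUDA device and the ninja
package, since the WKV kernel is JIT-compiled on first use.

A minimal usage sketch (illustrative parameter values, not defaults
mandated by this patch):

    from funasr.models.encoder.rwkv_encoder import RWKVEncoder

    encoder = RWKVEncoder(input_size=80, output_size=512, num_blocks=4)

In YAML configs, the encoder is selected with `encoder: rwkv`, the name
registered in build_asr_model.py below.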
---
funasr/build_utils/build_asr_model.py |   2 +
funasr/models/encoder/rwkv_encoder.py | 173 +++++++
funasr/modules/rwkv.py                | 148 ++++++
funasr/modules/rwkv_attention.py      | 653 ++++++++++++++++++++++++++++++
funasr/modules/rwkv_feed_forward.py   | 111 +++++
funasr/modules/rwkv_subsampling.py    | 205 ++++++++
6 files changed, 1292 insertions(+), 0 deletions(-)
diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
index 5e93444..fd47bd3 100644
--- a/funasr/build_utils/build_asr_model.py
+++ b/funasr/build_utils/build_asr_model.py
@@ -42,6 +42,7 @@
from funasr.models.encoder.branchformer_encoder import BranchformerEncoder
from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder
from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.encoder.rwkv_encoder import RWKVEncoder
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.default import MultiChannelFrontend
from funasr.models.frontend.fused import FusedFrontends
@@ -119,6 +120,7 @@
e_branchformer=EBranchformerEncoder,
mfcca_enc=MFCCAEncoder,
chunk_conformer=ConformerChunkEncoder,
+ rwkv=RWKVEncoder,
),
default="rnn",
)
diff --git a/funasr/models/encoder/rwkv_encoder.py b/funasr/models/encoder/rwkv_encoder.py
new file mode 100644
index 0000000..291ed19
--- /dev/null
+++ b/funasr/models/encoder/rwkv_encoder.py
@@ -0,0 +1,173 @@
+"""RWKV encoder definition for Transducer models."""
+
+from typing import Optional, Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.modules.rwkv import RWKV
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.rwkv_subsampling import RWKVConvInput
+from funasr.modules.nets_utils import make_source_mask
+
+class RWKVEncoder(AbsEncoder):
+ """RWKV encoder module.
+
+ Based on https://arxiv.org/pdf/2305.13048.pdf.
+
+ Args:
+        input_size: Input feature size.
+        output_size: Input/Output size of the RWKV blocks.
+        context_size: Context size for WKV computation.
+        linear_size: FeedForward hidden size.
+        attention_size: SelfAttention hidden size.
+        num_blocks: Number of RWKV blocks.
+        att_dropout_rate: Dropout rate for the attention module.
+        ffn_dropout_rate: Dropout rate for the feed-forward module.
+        dropout_rate: Dropout rate for residual connections.
+        subsampling_factor: Subsampling factor of the convolutional frontend.
+        time_reduction_factor: Extra time subsampling applied to the output.
+        kernel: Convolution kernel size in the frontend.
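+
+    Example (illustrative; running it requires a CUDA device, since the
+    WKV kernel is JIT-compiled for GPU):
+        >>> encoder = RWKVEncoder(80, output_size=256, num_blocks=2).cuda()
+        >>> feats = torch.randn(4, 200, 80, device="cuda")
+        >>> feats_len = torch.tensor([200, 180, 160, 120])
+        >>> out, olens, _ = encoder(feats, feats_len)  # out: (4, 50, 256)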
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 512,
+ context_size: int = 1024,
+ linear_size: Optional[int] = None,
+ attention_size: Optional[int] = None,
+ num_blocks: int = 4,
+ att_dropout_rate: float = 0.0,
+ ffn_dropout_rate: float = 0.0,
+ dropout_rate: float = 0.0,
+        subsampling_factor: int = 4,
+ time_reduction_factor: int = 1,
+ kernel: int = 3,
+ ) -> None:
+ """Construct a RWKVEncoder object."""
+ super().__init__()
+
+ assert check_argument_types()
+
+ self.embed = RWKVConvInput(
+ input_size,
+ [output_size//4, output_size//2, output_size],
+ subsampling_factor,
+ conv_kernel_size=kernel,
+ output_size=output_size,
+ )
+
+ self.subsampling_factor = subsampling_factor
+
+ linear_size = output_size * 4 if linear_size is None else linear_size
+ attention_size = output_size if attention_size is None else attention_size
+
+ self.rwkv_blocks = torch.nn.ModuleList(
+ [
+ RWKV(
+ output_size,
+ linear_size,
+ attention_size,
+ context_size,
+ block_id,
+ num_blocks,
+ att_dropout_rate=att_dropout_rate,
+ ffn_dropout_rate=ffn_dropout_rate,
+ dropout_rate=dropout_rate,
+ )
+ for block_id in range(num_blocks)
+ ]
+ )
+
+ self.embed_norm = LayerNorm(output_size)
+ self.final_norm = LayerNorm(output_size)
+
+ self._output_size = output_size
+ self.context_size = context_size
+
+ self.num_blocks = num_blocks
+ self.time_reduction_factor = time_reduction_factor
+
+ def output_size(self) -> int:
+ return self._output_size
+
+    def forward(
+        self, x: torch.Tensor, x_len: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
+        """Encode feature sequences.
+
+        Args:
+            x: Encoder input sequences. (B, T, D_feats)
+            x_len: Encoder input sequence lengths. (B,)
+
+        Returns:
+            x: Encoder output sequences. (B, sub(T), D_enc)
+            olens: Encoder output sequence lengths. (B,)
+
+        """
+ _, length, _ = x.size()
+
+ assert (
+ length <= self.context_size * self.subsampling_factor
+ ), "Context size is too short for current length: %d versus %d" % (
+ length,
+ self.context_size * self.subsampling_factor,
+ )
+ mask = make_source_mask(x_len).to(x.device)
+ x, mask = self.embed(x, mask, None)
+ x = self.embed_norm(x)
+ olens = mask.eq(0).sum(1)
+
+ for block in self.rwkv_blocks:
+ x, _ = block(x)
+ # for streaming inference
+ # xs_pad = self.rwkv_infer(xs_pad)
+
+ x = self.final_norm(x)
+
+ if self.time_reduction_factor > 1:
+            x = x[:, ::self.time_reduction_factor, :]
+            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
+
+ return x, olens, None
+
+    def rwkv_infer(self, xs_pad):
+        """Encode frame-by-frame with per-block hidden states (streaming)."""
+ batch_size = xs_pad.shape[0]
+
+ hidden_sizes = [
+ self._output_size for i in range(5)
+ ]
+
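+        # Hidden-state layout (the last dim indexes the RWKV block):
+        #   state[0] / state[1]: shifted inputs for channel / time mixing
+        #   state[2] / state[3]: WKV numerator / denominator
+        #   state[4]: running maximum, kept for numerical stability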
+        state = [
+            torch.zeros(
+                (batch_size, 1, hidden_sizes[i], self.num_blocks),
+                dtype=torch.float32,
+                device=xs_pad.device,
+            )
+            for i in range(5)
+        ]
+
+ state[4] -= 1e-30
+
+ xs_out = []
+ for t in range(xs_pad.shape[1]):
+ x_t = xs_pad[:,t,:]
+ for idx, block in enumerate(self.rwkv_blocks):
+ x_t, state = block(x_t, state=state)
+ xs_out.append(x_t)
+ xs_out = torch.stack(xs_out, dim=1)
+ return xs_out
diff --git a/funasr/modules/rwkv.py b/funasr/modules/rwkv.py
new file mode 100644
index 0000000..f020828
--- /dev/null
+++ b/funasr/modules/rwkv.py
@@ -0,0 +1,148 @@
+"""Receptance Weighted Key Value (RWKV) block definition.
+
+Based/modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py
+
+"""
+
+from typing import Optional, Tuple
+
+import torch
+
+from funasr.modules.rwkv_attention import EncoderSelfAttention, DecoderSelfAttention
+from funasr.modules.rwkv_feed_forward import FeedForward
+from funasr.modules.layer_norm import LayerNorm
+
+class RWKV(torch.nn.Module):
+ """RWKV module.
+
+ Args:
+ size: Input/Output size.
+ linear_size: Feed-forward hidden size.
+ attention_size: SelfAttention hidden size.
+ context_size: Context size for WKV computation.
+ block_id: Block index.
+ num_blocks: Number of blocks in the architecture.
+        att_dropout_rate: Dropout rate for the attention module.
+        ffn_dropout_rate: Dropout rate for the feed-forward module.
+        dropout_rate: Dropout rate for residual connections.
+
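+    Example (illustrative; constructing the block JIT-compiles the CUDA
+    WKV kernel, so a CUDA-enabled environment is assumed):
+        >>> block = RWKV(256, 1024, 256, 1024, 0, 4).cuda()
+        >>> x = torch.randn(2, 10, 256, device="cuda")
+        >>> x, _ = block(x)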
+ """
+
+ def __init__(
+ self,
+ size: int,
+ linear_size: int,
+ attention_size: int,
+ context_size: int,
+ block_id: int,
+ num_blocks: int,
+ att_dropout_rate: float = 0.0,
+ ffn_dropout_rate: float = 0.0,
+ dropout_rate: float = 0.0,
+ ) -> None:
+        """Construct an RWKV object."""
+ super().__init__()
+
+ self.layer_norm_att = LayerNorm(size)
+ self.layer_norm_ffn = LayerNorm(size)
+
+ self.att = EncoderSelfAttention(
+ size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
+ )
+ self.dropout_att = torch.nn.Dropout(p=dropout_rate)
+
+ self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks)
+ self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ state: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Compute receptance weighted key value.
+
+ Args:
+ x: RWKV input sequences. (B, L, size)
+ state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+ Returns:
+ x: RWKV output sequences. (B, L, size)
+            state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+ """
+ att, state = self.att(self.layer_norm_att(x), state=state)
+ x = x + self.dropout_att(att)
+ ffn, state = self.ffn(self.layer_norm_ffn(x), state=state)
+ x = x + self.dropout_ffn(ffn)
+ return x, state
+
+class RWKVDecoderLayer(torch.nn.Module):
+    """RWKV decoder layer module.
+
+ Args:
+ size: Input/Output size.
+ linear_size: Feed-forward hidden size.
+ attention_size: SelfAttention hidden size.
+ context_size: Context size for WKV computation.
+ block_id: Block index.
+ num_blocks: Number of blocks in the architecture.
+        att_dropout_rate: Dropout rate for the attention module.
+        ffn_dropout_rate: Dropout rate for the feed-forward module.
+        dropout_rate: Dropout rate for residual connections.
+
+ """
+
+ def __init__(
+ self,
+ size: int,
+ linear_size: int,
+ attention_size: int,
+ context_size: int,
+ block_id: int,
+ num_blocks: int,
+ att_dropout_rate: float = 0.0,
+ ffn_dropout_rate: float = 0.0,
+ dropout_rate: float = 0.0,
+ ) -> None:
+        """Construct an RWKVDecoderLayer object."""
+ super().__init__()
+
+ self.layer_norm_att = LayerNorm(size)
+ self.layer_norm_ffn = LayerNorm(size)
+
+ self.att = DecoderSelfAttention(
+ size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
+ )
+ self.dropout_att = torch.nn.Dropout(p=dropout_rate)
+
+ self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks)
+ self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ state: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ """Compute receptance weighted key value.
+
+ Args:
+ x: RWKV input sequences. (B, L, size)
+ state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+ Returns:
+ x: RWKV output sequences. (B, L, size)
+            state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+ """
+ att, state = self.att(self.layer_norm_att(x), state=state)
+ x = x + self.dropout_att(att)
+
+ ffn, state = self.ffn(self.layer_norm_ffn(x), state=state)
+ x = x + self.dropout_ffn(ffn)
+
+ return x, state
diff --git a/funasr/modules/rwkv_attention.py b/funasr/modules/rwkv_attention.py
new file mode 100644
index 0000000..f0c7da3
--- /dev/null
+++ b/funasr/modules/rwkv_attention.py
@@ -0,0 +1,653 @@
+"""Attention (time mixing) modules for RWKV block.
+
+Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py.
+
+Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
+
+""" # noqa
+
+import math
+from importlib.util import find_spec
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+import torch
+
+wkv_kernel_encoder = None
+wkv_kernel_decoder = None
+
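+# Both CUDA kernels compute, per channel, with w = exp(time_decay) > 0 applied
+# as a negative decay and u = time_first (cf. Eq. 14 of the RWKV paper):
+#
+#   wkv_t = (sum_{i<t} exp(-(t-1-i)*w + k_i) * v_i + exp(u + k_t) * v_t)
+#           / (sum_{i<t} exp(-(t-1-i)*w + k_i) + exp(u + k_t))
+#
+# i.e. an exponentially decaying average of past values, plus a learned
+# per-channel bonus for the current token.
+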
+class WKVLinearAttentionEncoder(torch.autograd.Function):
+    """WKVLinearAttention autograd function backed by the encoder CUDA kernel."""
+
+ @staticmethod
+ def forward(
+ ctx,
+ time_decay: torch.Tensor,
+ time_first: torch.Tensor,
+ key: torch.Tensor,
+        value: torch.Tensor,
+ ) -> torch.Tensor:
+ """WKVLinearAttention function forward pass.
+
+ Args:
+ time_decay: Channel-wise time decay vector. (D_att)
+ time_first: Channel-wise time first vector. (D_att)
+ key: Key tensor. (B, U, D_att)
+ value: Value tensor. (B, U, D_att)
+
+ Returns:
+ out: Weighted Key-Value tensor. (B, U, D_att)
+
+ """
+ batch, length, dim = key.size()
+
+ assert length <= wkv_kernel_encoder.context_size, (
+ f"Cannot process key of length {length} while context_size "
+ f"is ({wkv_kernel_encoder.context_size}). Limit should be increased."
+ )
+
+ assert batch * dim % min(dim, 32) == 0, (
+ f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
+ f"{min(dim, 32)}"
+ )
+
+ ctx.input_dtype = key.dtype
+
+ time_decay = -torch.exp(time_decay.float().contiguous())
+ time_first = time_first.float().contiguous()
+
+ key = key.float().contiguous()
+ value = value.float().contiguous()
+
+ out = torch.empty_like(key, memory_format=torch.contiguous_format)
+
+ wkv_kernel_encoder.forward(time_decay, time_first, key, value, out)
+ ctx.save_for_backward(time_decay, time_first, key, value, out)
+
+ return out
+
+ @staticmethod
+ def backward(
+ ctx, grad_output: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """WKVLinearAttention function backward pass.
+
+ Args:
+ grad_output: Output gradient. (B, U, D_att)
+
+ Returns:
+ grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
+ grad_time_first: Gradient for channel-wise time first vector. (D_att)
+ grad_key: Gradient for key tensor. (B, U, D_att)
+ grad_value: Gradient for value tensor. (B, U, D_att)
+
+ """
+ time_decay, time_first, key, value, output = ctx.saved_tensors
+ grad_dtype = ctx.input_dtype
+
+ batch, _, dim = key.size()
+
+ grad_time_decay = torch.empty(
+ (batch, dim),
+ memory_format=torch.contiguous_format,
+ dtype=time_decay.dtype,
+ device=time_decay.device,
+ )
+
+ grad_time_first = torch.empty(
+ (batch, dim),
+ memory_format=torch.contiguous_format,
+ dtype=time_decay.dtype,
+ device=time_decay.device,
+ )
+
+ grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
+ grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
+
+ wkv_kernel_encoder.backward(
+ time_decay,
+ time_first,
+ key,
+ value,
+ output,
+ grad_output.contiguous(),
+ grad_time_decay,
+ grad_time_first,
+ grad_key,
+ grad_value,
+ )
+
+ grad_time_decay = torch.sum(grad_time_decay, dim=0)
+ grad_time_first = torch.sum(grad_time_first, dim=0)
+
+ return (
+ grad_time_decay,
+ grad_time_first,
+ grad_key,
+ grad_value,
+ )
+
+class WKVLinearAttentionDecoder(torch.autograd.Function):
+    """WKVLinearAttention autograd function backed by the decoder CUDA kernel."""
+
+ @staticmethod
+ def forward(
+ ctx,
+ time_decay: torch.Tensor,
+ time_first: torch.Tensor,
+ key: torch.Tensor,
+        value: torch.Tensor,
+ ) -> torch.Tensor:
+ """WKVLinearAttention function forward pass.
+
+ Args:
+ time_decay: Channel-wise time decay vector. (D_att)
+ time_first: Channel-wise time first vector. (D_att)
+ key: Key tensor. (B, U, D_att)
+ value: Value tensor. (B, U, D_att)
+
+ Returns:
+ out: Weighted Key-Value tensor. (B, U, D_att)
+
+ """
+ batch, length, dim = key.size()
+
+ assert length <= wkv_kernel_decoder.context_size, (
+ f"Cannot process key of length {length} while context_size "
+            f"is ({wkv_kernel_decoder.context_size}). Limit should be increased."
+ )
+
+ assert batch * dim % min(dim, 32) == 0, (
+ f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
+ f"{min(dim, 32)}"
+ )
+
+ ctx.input_dtype = key.dtype
+
+ time_decay = -torch.exp(time_decay.float().contiguous())
+ time_first = time_first.float().contiguous()
+
+ key = key.float().contiguous()
+ value = value.float().contiguous()
+
+ out = torch.empty_like(key, memory_format=torch.contiguous_format)
+
+ wkv_kernel_decoder.forward(time_decay, time_first, key, value, out)
+ ctx.save_for_backward(time_decay, time_first, key, value, out)
+
+ return out
+
+ @staticmethod
+ def backward(
+ ctx, grad_output: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """WKVLinearAttention function backward pass.
+
+ Args:
+ grad_output: Output gradient. (B, U, D_att)
+
+ Returns:
+ grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
+ grad_time_first: Gradient for channel-wise time first vector. (D_att)
+ grad_key: Gradient for key tensor. (B, U, D_att)
+ grad_value: Gradient for value tensor. (B, U, D_att)
+
+ """
+ time_decay, time_first, key, value, output = ctx.saved_tensors
+ grad_dtype = ctx.input_dtype
+
+ batch, _, dim = key.size()
+
+ grad_time_decay = torch.empty(
+ (batch, dim),
+ memory_format=torch.contiguous_format,
+ dtype=time_decay.dtype,
+ device=time_decay.device,
+ )
+
+ grad_time_first = torch.empty(
+ (batch, dim),
+ memory_format=torch.contiguous_format,
+ dtype=time_decay.dtype,
+ device=time_decay.device,
+ )
+
+ grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
+ grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
+
+ wkv_kernel_decoder.backward(
+ time_decay,
+ time_first,
+ key,
+ value,
+ output,
+ grad_output.contiguous(),
+ grad_time_decay,
+ grad_time_first,
+ grad_key,
+ grad_value,
+ )
+
+ grad_time_decay = torch.sum(grad_time_decay, dim=0)
+ grad_time_first = torch.sum(grad_time_first, dim=0)
+
+ return (
+ grad_time_decay,
+ grad_time_first,
+ grad_key,
+ grad_value,
+ )
+
+def load_encoder_wkv_kernel(context_size: int) -> None:
+    """Load the encoder WKV CUDA kernel.
+
+ Args:
+ context_size: Context size.
+
+ """
+ from torch.utils.cpp_extension import load
+
+ global wkv_kernel_encoder
+
+ if wkv_kernel_encoder is not None and wkv_kernel_encoder.context_size == context_size:
+ return
+
+ if find_spec("ninja") is None:
+ raise ImportError(
+            "Ninja package was not found. WKV kernel module can't be loaded "
+            "for training. Please run 'pip install ninja' in your environment."
+ )
+
+ if not torch.cuda.is_available():
+ raise ImportError(
+ "CUDA is currently a requirement for WKV kernel loading. "
+ "Please set your devices properly and launch again."
+ )
+
+ kernel_folder = Path(__file__).resolve().parent / "cuda_encoder"
+ kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]
+
+ kernel_cflags = [
+ "-res-usage",
+ "--maxrregcount 60",
+ "--use_fast_math",
+ "-O3",
+ "-Xptxas -O3",
+ f"-DTmax={context_size}",
+ ]
+ wkv_kernel_encoder = load(
+ name=f"encoder_wkv_{context_size}",
+ sources=kernel_files,
+ verbose=True,
+ extra_cuda_cflags=kernel_cflags,
+ )
+ wkv_kernel_encoder.context_size = context_size
+
+def load_decoder_wkv_kernel(context_size: int) -> None:
+    """Load the decoder WKV CUDA kernel.
+
+ Args:
+ context_size: Context size.
+
+ """
+ from torch.utils.cpp_extension import load
+
+ global wkv_kernel_decoder
+
+ if wkv_kernel_decoder is not None and wkv_kernel_decoder.context_size == context_size:
+ return
+
+ if find_spec("ninja") is None:
+ raise ImportError(
+            "Ninja package was not found. WKV kernel module can't be loaded "
+            "for training. Please run 'pip install ninja' in your environment."
+ )
+
+ if not torch.cuda.is_available():
+ raise ImportError(
+ "CUDA is currently a requirement for WKV kernel loading. "
+ "Please set your devices properly and launch again."
+ )
+
+ kernel_folder = Path(__file__).resolve().parent / "cuda_decoder"
+ kernel_files = [kernel_folder / f for f in ["wkv_op.cpp", "wkv_cuda.cu"]]
+
+ kernel_cflags = [
+ "-res-usage",
+ "--maxrregcount 60",
+ "--use_fast_math",
+ "-O3",
+ "-Xptxas -O3",
+ f"-DTmax={context_size}",
+ ]
+ wkv_kernel_decoder = load(
+ name=f"decoder_wkv_{context_size}",
+ sources=kernel_files,
+ verbose=True,
+ extra_cuda_cflags=kernel_cflags,
+ )
+ wkv_kernel_decoder.context_size = context_size
+
+class SelfAttention(torch.nn.Module):
+    """SelfAttention (time mixing) module definition.
+
+ Args:
+        size: Input/Output size.
+        attention_size: Attention hidden size.
+        block_id: Block index.
+        dropout_rate: Dropout rate applied to the WKV output.
+        num_blocks: Number of blocks in the architecture.
+
+ """
+
+ def __init__(
+ self,
+ size: int,
+ attention_size: int,
+ block_id: int,
+ dropout_rate: float,
+ num_blocks: int,
+ ) -> None:
+ """Construct a SelfAttention object."""
+ super().__init__()
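+        # ZeroPad2d((0, 0, 1, -1)) shifts the sequence one step along time
+        # (prepends a zero frame and drops the last one), exposing x_{t-1}
+        # for the token-shift interpolations in forward().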
+ self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
+
+ self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
+ self.time_first = torch.nn.Parameter(torch.empty(attention_size))
+
+ self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
+ self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
+ self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))
+
+ self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
+ self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
+ self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)
+
+ self.proj_output = torch.nn.Linear(attention_size, size, bias=True)
+
+ self.block_id = block_id
+
+ self.reset_parameters(size, attention_size, block_id, num_blocks)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+ def reset_parameters(
+ self, size: int, attention_size: int, block_id: int, num_blocks: int
+ ) -> None:
+ """Reset module parameters.
+
+ Args:
+ size: Block size.
+ attention_size: Attention hidden size.
+ block_id: Block index.
+ num_blocks: Number of blocks in the architecture.
+
+ """
+ ratio_0_to_1 = block_id / (num_blocks - 1)
+ ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)
+
+ time_weight = torch.ones(1, 1, size)
+
+ for i in range(size):
+ time_weight[0, 0, i] = i / size
+
+ decay_speed = [
+ -5 + 8 * (h / (attention_size - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+ for h in range(attention_size)
+ ]
+ decay_speed = torch.tensor(
+ decay_speed, dtype=self.time_decay.dtype, device=self.time_decay.device
+ )
+
+ zigzag = (
+ torch.tensor(
+ [(i + 1) % 3 - 1 for i in range(attention_size)],
+ dtype=self.time_first.dtype,
+ device=self.time_first.device,
+ )
+ * 0.5
+ )
+
+ with torch.no_grad():
+ self.time_decay.data = decay_speed
+            self.time_first.data = (
+                torch.ones_like(self.time_first) * math.log(0.3) + zigzag
+            )
+
+ self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
+ self.time_mix_value.data = (
+ torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
+ )
+ self.time_mix_receptance.data = torch.pow(
+ time_weight, 0.5 * ratio_1_to_almost0
+ )
+
+ @torch.no_grad()
+ def wkv_linear_attention(
+ self,
+ time_decay: torch.Tensor,
+ time_first: torch.Tensor,
+ key: torch.Tensor,
+ value: torch.Tensor,
+ state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
+        """Compute WKV with state (i.e., for inference).
+
+ Args:
+ time_decay: Channel-wise time decay vector. (D_att)
+ time_first: Channel-wise time first vector. (D_att)
+ key: Key tensor. (B, 1, D_att)
+ value: Value tensor. (B, 1, D_att)
+            state: Decoder hidden states. [3 x (B, 1, D_att)]
+
+ Returns:
+ output: Weighted Key-Value. (B, 1, D_att)
+ state: Decoder hidden states. [3 x (B, 1, D_att)]
+
+ """
+ num_state, den_state, max_state = state
+
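+        # Streaming WKV step with the log-sum-exp trick: max_state tracks the
+        # running maximum exponent so every exp() below stays <= 1, keeping
+        # the numerator/denominator recurrence numerically stable.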
+ max_for_output = torch.maximum(max_state, (time_first + key))
+
+ e1 = torch.exp(max_state - max_for_output)
+ e2 = torch.exp((time_first + key) - max_for_output)
+
+ numerator = e1 * num_state + e2 * value
+ denominator = e1 * den_state + e2
+
+ max_for_state = torch.maximum(key, (max_state + time_decay))
+
+ e1 = torch.exp((max_state + time_decay) - max_for_state)
+ e2 = torch.exp(key - max_for_state)
+
+ wkv = numerator / denominator
+
+ state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]
+
+ return wkv, state
+
+
+class DecoderSelfAttention(SelfAttention):
+    """SelfAttention module using the decoder WKV kernel.
+
+ Args:
+ size: Input/Output size.
+ attention_size: Attention hidden size.
+ context_size: Context size for WKV kernel.
+ block_id: Block index.
+ num_blocks: Number of blocks in the architecture.
+
+ """
+
+ def __init__(
+ self,
+ size: int,
+ attention_size: int,
+ context_size: int,
+ block_id: int,
+ dropout_rate: float,
+ num_blocks: int,
+ ) -> None:
+        """Construct a DecoderSelfAttention object."""
+ super().__init__(
+ size,
+ attention_size,
+ block_id,
+ dropout_rate,
+ num_blocks
+ )
+ load_decoder_wkv_kernel(context_size)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ state: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+ """Compute time mixing.
+
+ Args:
+ x: SelfAttention input sequences. (B, U, size)
+ state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+        Returns:
+            x: SelfAttention output sequences. (B, U, size)
+            state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+ """
+ shifted_x = (
+ self.time_shift(x) if state is None else state[1][..., self.block_id]
+ )
+
+ key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
+ value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
+ receptance = x * self.time_mix_receptance + shifted_x * (
+ 1 - self.time_mix_receptance
+ )
+
+ key = self.proj_key(key)
+ value = self.proj_value(value)
+ receptance = torch.sigmoid(self.proj_receptance(receptance))
+
+ if state is not None:
+ state[1][..., self.block_id] = x
+
+ wkv, att_state = self.wkv_linear_attention(
+ self.time_decay,
+ self.time_first,
+ key,
+ value,
+ tuple(s[..., self.block_id] for s in state[2:]),
+ )
+
+ state[2][..., self.block_id] = att_state[0]
+ state[3][..., self.block_id] = att_state[1]
+ state[4][..., self.block_id] = att_state[2]
+ else:
+            wkv = WKVLinearAttentionDecoder.apply(
+                self.time_decay, self.time_first, key, value
+            )
+
+ wkv = self.dropout(wkv)
+ x = self.proj_output(receptance * wkv)
+
+ return x, state
+
+class EncoderSelfAttention(SelfAttention):
+    """SelfAttention module using the encoder WKV kernel.
+
+ Args:
+ size: Input/Output size.
+ attention_size: Attention hidden size.
+ context_size: Context size for WKV kernel.
+ block_id: Block index.
+ num_blocks: Number of blocks in the architecture.
+
+ """
+
+ def __init__(
+ self,
+ size: int,
+ attention_size: int,
+ context_size: int,
+ block_id: int,
+ dropout_rate: float,
+ num_blocks: int,
+ ) -> None:
+        """Construct an EncoderSelfAttention object."""
+ super().__init__(
+ size,
+ attention_size,
+ block_id,
+ dropout_rate,
+ num_blocks
+ )
+ load_encoder_wkv_kernel(context_size)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ state: Optional[List[torch.Tensor]] = None,
+ ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+ """Compute time mixing.
+
+ Args:
+ x: SelfAttention input sequences. (B, U, size)
+ state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+        Returns:
+            x: SelfAttention output sequences. (B, U, size)
+            state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+ """
+ shifted_x = (
+ self.time_shift(x) if state is None else state[1][..., self.block_id]
+ )
+
+ key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
+ value = x * self.time_mix_value + shifted_x * (1 - self.time_mix_value)
+ receptance = x * self.time_mix_receptance + shifted_x * (
+ 1 - self.time_mix_receptance
+ )
+
+ key = self.proj_key(key)
+ value = self.proj_value(value)
+ receptance = torch.sigmoid(self.proj_receptance(receptance))
+
+ if state is not None:
+ state[1][..., self.block_id] = x
+
+ wkv, att_state = self.wkv_linear_attention(
+ self.time_decay,
+ self.time_first,
+ key,
+ value,
+ tuple(s[..., self.block_id] for s in state[2:]),
+ )
+
+ state[2][..., self.block_id] = att_state[0]
+ state[3][..., self.block_id] = att_state[1]
+ state[4][..., self.block_id] = att_state[2]
+ else:
+            wkv = WKVLinearAttentionEncoder.apply(
+                self.time_decay, self.time_first, key, value
+            )
+
+ wkv = self.dropout(wkv)
+ x = self.proj_output(receptance * wkv)
+
+ return x, state
+
diff --git a/funasr/modules/rwkv_feed_forward.py b/funasr/modules/rwkv_feed_forward.py
new file mode 100644
index 0000000..ddb4285
--- /dev/null
+++ b/funasr/modules/rwkv_feed_forward.py
@@ -0,0 +1,111 @@
+"""Feed-forward (channel mixing) module for RWKV block.
+
+Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py
+
+Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
+
+""" # noqa
+
+from typing import List, Optional, Tuple
+
+import torch
+
+
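+# Channel mixing computes, per step t (with the same token shift as time
+# mixing; mu are the learned time_mix weights; dropout omitted):
+#
+#   k_t = W_key(mu_k * x_t + (1 - mu_k) * x_{t-1})
+#   r_t = sigmoid(W_receptance(mu_r * x_t + (1 - mu_r) * x_{t-1}))
+#   out_t = r_t * W_value(max(k_t, 0)^2)    # squared-ReLU nonlinearity
+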
+class FeedForward(torch.nn.Module):
+ """FeedForward module definition.
+
+ Args:
+ size: Input/Output size.
+ hidden_size: Hidden size.
+        block_id: Block index.
+        dropout_rate: Dropout rate applied inside the feed-forward.
+        num_blocks: Number of blocks in the architecture.
+
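+    Example (a minimal sketch; channel mixing needs no CUDA kernel):
+        >>> ffn = FeedForward(256, 1024, 0, 0.0, 4)
+        >>> x = torch.randn(2, 10, 256)
+        >>> y, _ = ffn(x)
+        >>> y.shape
+        torch.Size([2, 10, 256])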
+ """
+
+ def __init__(
+ self, size: int, hidden_size: int, block_id: int, dropout_rate: float, num_blocks: int
+ ) -> None:
+ """Construct a FeedForward object."""
+ super().__init__()
+
+ self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
+
+ self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
+ self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))
+
+ self.proj_key = torch.nn.Linear(size, hidden_size, bias=True)
+ self.proj_value = torch.nn.Linear(hidden_size, size, bias=True)
+ self.proj_receptance = torch.nn.Linear(size, size, bias=True)
+
+ self.block_id = block_id
+
+ self.reset_parameters(size, block_id, num_blocks)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+ def reset_parameters(self, size: int, block_id: int, num_blocks: int) -> None:
+ """Reset module parameters.
+
+ Args:
+ size: Block size.
+ block_id: Block index.
+ num_blocks: Number of blocks in the architecture.
+
+ """
+ ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)
+
+ time_weight = torch.ones(1, 1, size)
+
+ for i in range(size):
+ time_weight[0, 0, i] = i / size
+
+ with torch.no_grad():
+ self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
+ self.time_mix_receptance.data = torch.pow(time_weight, ratio_1_to_almost0)
+
+ def forward(
+ self, x: torch.Tensor, state: Optional[List[torch.Tensor]] = None
+ ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+ """Compute channel mixing.
+
+ Args:
+ x: FeedForward input sequences. (B, U, size)
+ state: Decoder hidden state. [5 x (B, 1, size, N)]
+
+ Returns:
+ x: FeedForward output sequences. (B, U, size)
+ state: Decoder hidden state. [5 x (B, 1, size, N)]
+
+ """
+ shifted_x = (
+ self.time_shift(x) if state is None else state[0][..., self.block_id]
+ )
+
+ key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
+ receptance = x * self.time_mix_receptance + shifted_x * (
+ 1 - self.time_mix_receptance
+ )
+
+ key = torch.square(torch.relu(self.proj_key(key)))
+ value = self.proj_value(self.dropout(key))
+ receptance = torch.sigmoid(self.proj_receptance(receptance))
+
+ if state is not None:
+ state[0][..., self.block_id] = x
+
+ x = receptance * value
+
+ return x, state
diff --git a/funasr/modules/rwkv_subsampling.py b/funasr/modules/rwkv_subsampling.py
new file mode 100644
index 0000000..4277093
--- /dev/null
+++ b/funasr/modules/rwkv_subsampling.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Subsampling layer definition."""
+import math
+from typing import Optional, Tuple, Union
+
+import torch
+
+from funasr.modules.nets_utils import pad_to_len
+
+class TooShortUttError(Exception):
+ """Raised when the utt is too short for subsampling.
+
+ Args:
+ message (str): Message for error catch
+ actual_size (int): the short size that cannot pass the subsampling
+ limit (int): the limit size for subsampling
+
+ """
+
+ def __init__(self, message, actual_size, limit):
+ """Construct a TooShortUttError for error handler."""
+ super().__init__(message)
+ self.actual_size = actual_size
+ self.limit = limit
+
+
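+# NOTE: the Conv2dSubsampling* classes referenced below are not defined in
+# this file; this helper assumes they are imported from the pre-existing
+# subsampling module wherever it is actually called.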
+def check_short_utt(ins, size):
+ """Check if the utterance is too short for subsampling."""
+ if isinstance(ins, Conv2dSubsampling2) and size < 3:
+ return True, 3
+ if isinstance(ins, Conv2dSubsampling) and size < 7:
+ return True, 7
+ if isinstance(ins, Conv2dSubsampling6) and size < 11:
+ return True, 11
+ if isinstance(ins, Conv2dSubsampling8) and size < 15:
+ return True, 15
+ return False, -1
+
+
+class RWKVConvInput(torch.nn.Module):
+    """Streaming ConvInput module definition.
+
+    Args:
+        input_size: Input size.
+        conv_size: Channel sizes of the three convolution stages.
+        subsampling_factor: Subsampling factor.
+        conv_kernel_size: Convolution kernel size.
+        output_size: Block output dimension.
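+
+    Example (illustrative shapes for 80-dim features and 4x subsampling):
+        >>> conv_in = RWKVConvInput(80, [64, 128, 256], 4, output_size=512)
+        >>> x = torch.randn(2, 100, 80)
+        >>> mask = torch.zeros(2, 100, dtype=torch.bool)  # no padding
+        >>> out, out_mask = conv_in(x, mask, None)
+        >>> out.shape
+        torch.Size([2, 25, 512])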
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ conv_size: Union[int, Tuple],
+ subsampling_factor: int = 4,
+ conv_kernel_size: int = 3,
+ output_size: Optional[int] = None,
+ ) -> None:
+ """Construct a ConvInput object."""
+ super().__init__()
+ if subsampling_factor == 1:
+ conv_size1, conv_size2, conv_size3 = conv_size
+
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size2, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size3, conv_size3, conv_kernel_size, stride=[1, 2], padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ )
+
+            output_proj = conv_size3 * (((input_size // 2) // 2) // 2)
+
+ self.subsampling_factor = 1
+
+ self.stride_1 = 1
+
+ self.create_new_mask = self.create_new_vgg_mask
+
+ else:
+ conv_size1, conv_size2, conv_size3 = conv_size
+
+            kernel_1 = subsampling_factor // 2  # time stride of the first subsampling conv
+
+ self.conv = torch.nn.Sequential(
+ torch.nn.Conv2d(1, conv_size1, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size1, conv_size1, conv_kernel_size, stride=[kernel_1, 2], padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size1, conv_size2, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size2, conv_size2, conv_kernel_size, stride=[2, 2], padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size2, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ torch.nn.Conv2d(conv_size3, conv_size3, conv_kernel_size, stride=1, padding=(conv_kernel_size-1)//2),
+ torch.nn.ReLU(),
+ )
+
+ output_proj = conv_size3 * ((input_size // 2) // 2)
+
+ self.subsampling_factor = subsampling_factor
+
+ self.create_new_mask = self.create_new_vgg_mask
+
+ self.stride_1 = kernel_1
+
+ self.min_frame_length = 7
+
+ if output_size is not None:
+ self.output = torch.nn.Linear(output_proj, output_size)
+ self.output_size = output_size
+ else:
+ self.output = None
+ self.output_size = output_proj
+
+ def forward(
+ self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: ConvInput input sequences. (B, T, D_feats)
+ mask: Mask of input sequences. (B, 1, T)
+ Returns:
+ x: ConvInput output sequences. (B, sub(T), D_out)
+ mask: Mask of output sequences. (B, 1, sub(T))
+ """
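+        # The mask is assumed to follow make_source_mask conventions (non-zero
+        # at padded positions), so the valid length is the per-row count of
+        # zeros once the mask itself has been subsampled.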
+ if mask is not None:
+            mask = self.create_new_mask(mask)
+            olens = mask.eq(0).sum(1).max()
+
+ b, t, f = x.size()
+        x = x.unsqueeze(1)  # (b, 1, t, f)
+
+ if chunk_size is not None:
+            max_input_length = int(
+                chunk_size
+                * self.subsampling_factor
+                * math.ceil(float(t) / (chunk_size * self.subsampling_factor))
+            )
+ x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x)
+ x = list(x)
+ x = torch.stack(x, dim=0)
+            n_chunks = max_input_length // (chunk_size * self.subsampling_factor)
+            x = x.view(b * n_chunks, 1, chunk_size * self.subsampling_factor, f)
+
+ x = self.conv(x)
+
+ _, c, _, f = x.size()
+ if chunk_size is not None:
+            x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:, :olens, :]
+ else:
+ x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
+
+ if self.output is not None:
+ x = self.output(x)
+
+        return x, mask[:, :olens][:, :x.size(1)]
+
+ def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
+ """Create a new mask for VGG output sequences.
+ Args:
+ mask: Mask of input sequences. (B, T)
+ Returns:
+ mask: Mask of output sequences. (B, sub(T))
+ """
+ if self.subsampling_factor > 1:
+ return mask[:, ::2][:, ::self.stride_1]
+ else:
+ return mask
+
+ def get_size_before_subsampling(self, size: int) -> int:
+ """Return the original size before subsampling for a given size.
+ Args:
+ size: Number of frames after subsampling.
+ Returns:
+            Number of frames before subsampling.
+ """
+ return size * self.subsampling_factor
--
Gitblit v1.9.1