| | |
| | | from funasr.models.encoder.branchformer_encoder import BranchformerEncoder |
| | | from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder |
| | | from funasr.models.encoder.transformer_encoder import TransformerEncoder |
| | | from funasr.models.encoder.rwkv_encoder import RWKVEncoder |
| | | from funasr.models.frontend.default import DefaultFrontend |
| | | from funasr.models.frontend.default import MultiChannelFrontend |
| | | from funasr.models.frontend.fused import FusedFrontends |
| | |
| | | e_branchformer=EBranchformerEncoder, |
| | | mfcca_enc=MFCCAEncoder, |
| | | chunk_conformer=ConformerChunkEncoder, |
| | | rwkv=RWKVEncoder, |
| | | ), |
| | | default="rnn", |
| | | ) |
| New file |
| | |
| | | """RWKV encoder definition for Transducer models.""" |
| | | |
| | | import math |
| | | from typing import Dict, List, Optional, Tuple |
| | | |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.rwkv import RWKV |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.rwkv_subsampling import RWKVConvInput |
| | | from funasr.modules.nets_utils import make_source_mask |
| | | |
class RWKVEncoder(AbsEncoder):
    """RWKV encoder module.

    Based on https://arxiv.org/pdf/2305.13048.pdf.

    Args:
        input_size: Input feature size.
        output_size: Input/Output size of the RWKV blocks.
        context_size: Context size for WKV computation.
        linear_size: FeedForward hidden size (``4 * output_size`` if None).
        attention_size: SelfAttention hidden size (``output_size`` if None).
        num_blocks: Number of RWKV blocks.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate forwarded to every RWKV block.
        subsampling_factor: Subsampling factor forwarded to the input module.
        time_reduction_factor: Keep one output frame out of this many.
        kernel: Convolution kernel size of the input module.

    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 512,
        context_size: int = 1024,
        linear_size: Optional[int] = None,
        attention_size: Optional[int] = None,
        num_blocks: int = 4,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
        subsampling_factor: int = 4,
        time_reduction_factor: int = 1,
        kernel: int = 3,
    ) -> None:
        """Construct a RWKVEncoder object."""
        super().__init__()

        assert check_argument_types()

        self.embed = RWKVConvInput(
            input_size,
            [output_size // 4, output_size // 2, output_size],
            subsampling_factor,
            conv_kernel_size=kernel,
            output_size=output_size,
        )

        self.subsampling_factor = subsampling_factor

        linear_size = output_size * 4 if linear_size is None else linear_size
        attention_size = output_size if attention_size is None else attention_size

        self.rwkv_blocks = torch.nn.ModuleList(
            [
                RWKV(
                    output_size,
                    linear_size,
                    attention_size,
                    context_size,
                    block_id,
                    num_blocks,
                    att_dropout_rate=att_dropout_rate,
                    ffn_dropout_rate=ffn_dropout_rate,
                    dropout_rate=dropout_rate,
                )
                for block_id in range(num_blocks)
            ]
        )

        self.embed_norm = LayerNorm(output_size)
        self.final_norm = LayerNorm(output_size)

        self._output_size = output_size
        # Kept so that rwkv_infer() can size the WKV state slots correctly.
        self.attention_size = attention_size
        self.context_size = context_size

        self.num_blocks = num_blocks
        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        """Return the encoder output size."""
        return self._output_size

    def forward(
        self, x: torch.Tensor, x_len
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """Encode source feature sequences.

        Args:
            x: Encoder input sequences; must be 3-dimensional, with the
               sequence length on the second axis. (B, T, D_feats)
            x_len: Input lengths, forwarded to make_source_mask.
                   NOTE(review): presumably a (B,) tensor - confirm.

        Returns:
            x: Encoder output sequences. (B, U, D)
            olens: Output lengths derived from the subsampled mask.
            None: Placeholder third element (always None).

        """
        _, length, _ = x.size()

        assert (
            length <= self.context_size * self.subsampling_factor
        ), "Context size is too short for current length: %d versus %d" % (
            length,
            self.context_size * self.subsampling_factor,
        )
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        x = self.embed_norm(x)
        # Counts the zero entries of the subsampled mask per utterance.
        # NOTE(review): assumes zero marks valid (non-padded) frames - confirm
        # against make_source_mask / RWKVConvInput.
        olens = mask.eq(0).sum(1)

        # Full-sequence (kernel based) path; rwkv_infer() is the frame-by-frame
        # alternative for streaming inference.
        for block in self.rwkv_blocks:
            x, _ = block(x)

        x = self.final_norm(x)

        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens, None

    def rwkv_infer(self, xs_pad: torch.Tensor) -> torch.Tensor:
        """Run the RWKV blocks frame by frame with an explicit recurrent state.

        Only the RWKV blocks are applied here; the input module and the layer
        normalizations used by forward() are not.

        Args:
            xs_pad: RWKV block input sequences. (B, T, output_size)

        Returns:
            xs_out: RWKV block output sequences. (B, T, output_size)

        """
        batch_size = xs_pad.size(0)

        # state[0] / state[1] hold the previous frame for the feed-forward and
        # attention time shifts (block size); state[2:5] hold the WKV
        # numerator, denominator and running maximum (attention size).
        hidden_sizes = [self._output_size] * 2 + [self.attention_size] * 3

        state = [
            torch.zeros(
                (batch_size, 1, hidden_size, self.num_blocks),
                dtype=torch.float32,
                device=xs_pad.device,
            )
            for hidden_size in hidden_sizes
        ]

        state[4] -= 1e-30

        xs_out = []
        for t in range(xs_pad.size(1)):
            # Keep the time axis so the frame matches the (B, 1, D) state slots.
            x_t = xs_pad[:, t : t + 1, :]
            for block in self.rwkv_blocks:
                x_t, state = block(x_t, state=state)
            xs_out.append(x_t)

        return torch.cat(xs_out, dim=1)
| New file |
| | |
| | | """Receptance Weighted Key Value (RWKV) block definition. |
| | | |
| | | Based/modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py |
| | | |
| | | """ |
| | | |
| | | from typing import Dict, Optional, Tuple |
| | | |
| | | import torch |
| | | |
| | | from funasr.modules.rwkv_attention import EncoderSelfAttention, DecoderSelfAttention |
| | | from funasr.modules.rwkv_feed_forward import FeedForward |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | |
class RWKV(torch.nn.Module):
    """RWKV encoder block: time mixing followed by channel mixing.

    Args:
        size: Input/Output size.
        linear_size: Feed-forward hidden size.
        attention_size: SelfAttention hidden size.
        context_size: Context size for WKV computation.
        block_id: Block index.
        num_blocks: Number of blocks in the architecture.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate applied to each sub-module output before
                      it is added back to the residual stream.

    """

    def __init__(
        self,
        size: int,
        linear_size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        num_blocks: int,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a RWKV object."""
        super().__init__()

        # Pre-normalization layers, one per residual branch.
        self.layer_norm_att = LayerNorm(size)
        self.layer_norm_ffn = LayerNorm(size)

        # Time-mixing branch.
        self.att = EncoderSelfAttention(
            size=size,
            attention_size=attention_size,
            context_size=context_size,
            block_id=block_id,
            dropout_rate=att_dropout_rate,
            num_blocks=num_blocks,
        )
        self.dropout_att = torch.nn.Dropout(p=dropout_rate)

        # Channel-mixing branch.
        self.ffn = FeedForward(
            size=size,
            hidden_size=linear_size,
            block_id=block_id,
            dropout_rate=ffn_dropout_rate,
            num_blocks=num_blocks,
        )
        self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply both residual branches of the block.

        Args:
            x: RWKV input sequences. (B, L, size)
            state: Hidden states, passed through both sub-modules (or None).

        Returns:
            x: RWKV output sequences. (B, L, size)
            state: Hidden states as returned by the feed-forward module.

        """
        residual = x

        mixed, state = self.att(self.layer_norm_att(residual), state=state)
        residual = residual + self.dropout_att(mixed)

        mixed, state = self.ffn(self.layer_norm_ffn(residual), state=state)
        return residual + self.dropout_ffn(mixed), state
| | | |
class RWKVDecoderLayer(torch.nn.Module):
    """RWKV decoder layer: time mixing followed by channel mixing.

    Same layout as the encoder block, but the time-mixing branch is a
    DecoderSelfAttention.

    Args:
        size: Input/Output size.
        linear_size: Feed-forward hidden size.
        attention_size: SelfAttention hidden size.
        context_size: Context size for WKV computation.
        block_id: Block index.
        num_blocks: Number of blocks in the architecture.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate applied to each sub-module output before
                      it is added back to the residual stream.

    """

    def __init__(
        self,
        size: int,
        linear_size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        num_blocks: int,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a RWKVDecoderLayer object."""
        super().__init__()

        # Pre-normalization layers, one per residual branch.
        self.layer_norm_att = LayerNorm(size)
        self.layer_norm_ffn = LayerNorm(size)

        # Time-mixing branch.
        self.att = DecoderSelfAttention(
            size=size,
            attention_size=attention_size,
            context_size=context_size,
            block_id=block_id,
            dropout_rate=att_dropout_rate,
            num_blocks=num_blocks,
        )
        self.dropout_att = torch.nn.Dropout(p=dropout_rate)

        # Channel-mixing branch.
        self.ffn = FeedForward(
            size=size,
            hidden_size=linear_size,
            block_id=block_id,
            dropout_rate=ffn_dropout_rate,
            num_blocks=num_blocks,
        )
        self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply both residual branches of the layer.

        Args:
            x: RWKV input sequences. (B, L, size)
            state: Hidden states, passed through both sub-modules (or None).

        Returns:
            x: RWKV output sequences. (B, L, size)
            state: Hidden states as returned by the feed-forward module.

        """
        residual = x

        mixed, state = self.att(self.layer_norm_att(residual), state=state)
        residual = residual + self.dropout_att(mixed)

        mixed, state = self.ffn(self.layer_norm_ffn(residual), state=state)
        return residual + self.dropout_ffn(mixed), state
| New file |
| | |
| | | """Attention (time mixing) modules for RWKV block. |
| | | |
| | | Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py. |
| | | |
| | | Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py. |
| | | |
| | | """ # noqa |
| | | |
| | | import math |
| | | from importlib.util import find_spec |
| | | from pathlib import Path |
| | | from typing import List, Optional, Tuple, Union |
| | | |
| | | import torch |
| | | |
| | | wkv_kernel_encoder = None |
| | | wkv_kernel_decoder = None |
| | | |
class WKVLinearAttentionEncoder(torch.autograd.Function):
    """WKV linear attention backed by the encoder CUDA kernel.

    The kernel is read from the module-level ``wkv_kernel_encoder`` global,
    which must have been set by ``load_encoder_wkv_kernel`` beforehand.

    """

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """WKVLinearAttention function forward pass.

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, U, D_att)
            value: Value tensor. (B, U, D_att)

        Returns:
            out: Weighted Key-Value tensor. (B, U, D_att)

        """
        batch, length, dim = key.size()

        assert length <= wkv_kernel_encoder.context_size, (
            f"Cannot process key of length {length} while context_size "
            f"is ({wkv_kernel_encoder.context_size}). Limit should be increased."
        )

        assert batch * dim % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
        )

        ctx.input_dtype = key.dtype

        # Every tensor handed to the kernel is float32 and contiguous; the
        # decay is passed as -exp(time_decay).
        time_decay = -torch.exp(time_decay.float().contiguous())
        time_first = time_first.float().contiguous()

        key = key.float().contiguous()
        value = value.float().contiguous()

        # The kernel writes its result into this preallocated buffer.
        out = torch.empty_like(key, memory_format=torch.contiguous_format)

        wkv_kernel_encoder.forward(time_decay, time_first, key, value, out)
        ctx.save_for_backward(time_decay, time_first, key, value, out)

        return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors

        batch, _, dim = key.size()

        # The kernel fills per-sample gradients for the two channel-wise
        # vectors; they are summed over the batch axis afterwards.
        grad_time_decay = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_time_first = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        wkv_kernel_encoder.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )

        grad_time_decay = torch.sum(grad_time_decay, dim=0)
        grad_time_first = torch.sum(grad_time_first, dim=0)

        return (
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )
| | | |
class WKVLinearAttentionDecoder(torch.autograd.Function):
    """WKV linear attention backed by the decoder CUDA kernel.

    The kernel is read from the module-level ``wkv_kernel_decoder`` global,
    which must have been set by ``load_decoder_wkv_kernel`` beforehand.

    """

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """WKVLinearAttention function forward pass.

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, U, D_att)
            value: Value tensor. (B, U, D_att)

        Returns:
            out: Weighted Key-Value tensor. (B, U, D_att)

        """
        batch, length, dim = key.size()

        # The message must reference the decoder kernel: the previous
        # `wkv_kernel` name was undefined and turned a failing assertion into
        # a NameError.
        assert length <= wkv_kernel_decoder.context_size, (
            f"Cannot process key of length {length} while context_size "
            f"is ({wkv_kernel_decoder.context_size}). Limit should be increased."
        )

        assert batch * dim % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
        )

        ctx.input_dtype = key.dtype

        # Every tensor handed to the kernel is float32 and contiguous; the
        # decay is passed as -exp(time_decay).
        time_decay = -torch.exp(time_decay.float().contiguous())
        time_first = time_first.float().contiguous()

        key = key.float().contiguous()
        value = value.float().contiguous()

        # The kernel writes its result into this preallocated buffer.
        out = torch.empty_like(key, memory_format=torch.contiguous_format)

        wkv_kernel_decoder.forward(time_decay, time_first, key, value, out)
        ctx.save_for_backward(time_decay, time_first, key, value, out)

        return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors

        batch, _, dim = key.size()

        # The kernel fills per-sample gradients for the two channel-wise
        # vectors; they are summed over the batch axis afterwards.
        grad_time_decay = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_time_first = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        wkv_kernel_decoder.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )

        grad_time_decay = torch.sum(grad_time_decay, dim=0)
        grad_time_first = torch.sum(grad_time_first, dim=0)

        return (
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )
| | | |
def load_encoder_wkv_kernel(context_size: int) -> None:
    """Compile and load the encoder WKV CUDA kernel.

    The loaded extension is stored in the module-level ``wkv_kernel_encoder``
    global and tagged with the context size it was built for. Nothing is done
    when a kernel for the same context size is already loaded.

    Args:
        context_size: Context size.

    Raises:
        ImportError: If the ninja package is missing or CUDA is unavailable.

    """
    from torch.utils.cpp_extension import load

    global wkv_kernel_encoder

    already_loaded = (
        wkv_kernel_encoder is not None
        and wkv_kernel_encoder.context_size == context_size
    )
    if already_loaded:
        return

    if find_spec("ninja") is None:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please, 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    # Kernel sources live next to this file.
    source_dir = Path(__file__).resolve().parent / "cuda_encoder"
    sources = [source_dir / "wkv_op.cpp", source_dir / "wkv_cuda.cu"]

    # The context size is baked into the kernel as the Tmax macro.
    cuda_flags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        f"-DTmax={context_size}",
    ]

    kernel = load(
        name=f"encoder_wkv_{context_size}",
        sources=sources,
        verbose=True,
        extra_cuda_cflags=cuda_flags,
    )
    wkv_kernel_encoder = kernel
    wkv_kernel_encoder.context_size = context_size
| | | |
def load_decoder_wkv_kernel(context_size: int) -> None:
    """Compile and load the decoder WKV CUDA kernel.

    The loaded extension is stored in the module-level ``wkv_kernel_decoder``
    global and tagged with the context size it was built for. Nothing is done
    when a kernel for the same context size is already loaded.

    Args:
        context_size: Context size.

    Raises:
        ImportError: If the ninja package is missing or CUDA is unavailable.

    """
    from torch.utils.cpp_extension import load

    global wkv_kernel_decoder

    already_loaded = (
        wkv_kernel_decoder is not None
        and wkv_kernel_decoder.context_size == context_size
    )
    if already_loaded:
        return

    if find_spec("ninja") is None:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please, 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    # Kernel sources live next to this file.
    source_dir = Path(__file__).resolve().parent / "cuda_decoder"
    sources = [source_dir / "wkv_op.cpp", source_dir / "wkv_cuda.cu"]

    # The context size is baked into the kernel as the Tmax macro.
    cuda_flags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        f"-DTmax={context_size}",
    ]

    kernel = load(
        name=f"decoder_wkv_{context_size}",
        sources=sources,
        verbose=True,
        extra_cuda_cflags=cuda_flags,
    )
    wkv_kernel_decoder = kernel
    wkv_kernel_decoder.context_size = context_size
| | | |
class SelfAttention(torch.nn.Module):
    """SelfAttention (time mixing) base module.

    Holds the parameters, their initialization and the stateful WKV
    computation; it defines no forward() of its own.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        block_id: Block index.
        dropout_rate: Dropout rate of the ``dropout`` layer.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a SelfAttention object."""
        super().__init__()
        # Pads one zero frame at the front and crops the last frame, i.e.
        # delays a (B, U, size) sequence by one step along the time axis.
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

        self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
        self.time_first = torch.nn.Parameter(torch.empty(attention_size))

        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))

        self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)

        self.proj_output = torch.nn.Linear(attention_size, size, bias=True)

        self.block_id = block_id

        self.reset_parameters(size, attention_size, block_id, num_blocks)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def reset_parameters(
        self, size: int, attention_size: int, block_id: int, num_blocks: int
    ) -> None:
        """Reset module parameters.

        Args:
            size: Block size.
            attention_size: Attention hidden size.
            block_id: Block index.
            num_blocks: Number of blocks in the architecture.

        """
        # Both denominators are clamped to 1 so that a single-block model or
        # a single attention channel does not raise ZeroDivisionError.
        ratio_0_to_1 = block_id / max(num_blocks - 1, 1)
        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

        # Channel-wise ramp i / size, shaped like the time-mix parameters.
        time_weight = torch.tensor([i / size for i in range(size)]).view(1, 1, size)

        decay_exponent = 0.7 + 1.3 * ratio_0_to_1
        max_index = max(attention_size - 1, 1)
        decay_speed = torch.tensor(
            [-5 + 8 * (h / max_index) ** decay_exponent for h in range(attention_size)],
            dtype=self.time_decay.dtype,
            device=self.time_decay.device,
        )

        # Repeating (0, 0.5, -0.5) offset across the attention channels.
        zigzag = (
            torch.tensor(
                [(i + 1) % 3 - 1 for i in range(attention_size)],
                dtype=self.time_first.dtype,
                device=self.time_first.device,
            )
            * 0.5
        )

        with torch.no_grad():
            self.time_decay.data = decay_speed
            # The scale and offset must be applied OUTSIDE ones_like();
            # wrapping the whole expression produced an all-ones tensor and
            # discarded the log(0.3) + zigzag initialization.
            self.time_first.data = (
                torch.ones_like(self.time_first) * math.log(0.3) + zigzag
            )

            self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
            self.time_mix_value.data = (
                torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
            )
            self.time_mix_receptance.data = torch.pow(
                time_weight, 0.5 * ratio_1_to_almost0
            )

    @torch.no_grad()
    def wkv_linear_attention(
        self,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
        """Compute WKV with state (i.e.: for inference).

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, 1, D_att)
            value: Value tensor. (B, 1, D_att)
            state: Hidden states (numerator, denominator, running maximum).

        Returns:
            output: Weighted Key-Value. (B, 1, D_att)
            state: Updated hidden states, returned as a 3-element list in the
                   same (numerator, denominator, running maximum) order.

        """
        num_state, den_state, max_state = state

        # Output for the current step: both exponentials are taken relative
        # to their shared maximum.
        boosted_key = time_first + key
        max_for_output = torch.maximum(max_state, boosted_key)

        e1 = torch.exp(max_state - max_for_output)
        e2 = torch.exp(boosted_key - max_for_output)

        wkv = (e1 * num_state + e2 * value) / (e1 * den_state + e2)

        # State for the next step: decay the previous state and fold in the
        # current key/value, again relative to the new running maximum.
        decayed_max = max_state + time_decay
        max_for_state = torch.maximum(key, decayed_max)

        e1 = torch.exp(decayed_max - max_for_state)
        e2 = torch.exp(key - max_for_state)

        state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]

        return wkv, state
| | | |
| | | |
class DecoderSelfAttention(SelfAttention):
    """Time-mixing module using the decoder WKV kernel.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        dropout_rate: Dropout rate applied to the WKV output.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a DecoderSelfAttention object."""
        super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)

        # Constructing the module loads the decoder kernel.
        load_decoder_wkv_kernel(context_size)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Decoder hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: The hidden states given as input, updated in place
                   (None when no state was given).

        """
        block = self.block_id
        stateful = state is not None

        # Previous-step input: taken from the state when one is given,
        # otherwise the input delayed by one frame.
        previous = state[1][..., block] if stateful else self.time_shift(x)

        mix_k = self.time_mix_key
        mix_v = self.time_mix_value
        mix_r = self.time_mix_receptance

        key = self.proj_key(x * mix_k + previous * (1 - mix_k))
        value = self.proj_value(x * mix_v + previous * (1 - mix_v))
        receptance = torch.sigmoid(
            self.proj_receptance(x * mix_r + previous * (1 - mix_r))
        )

        if stateful:
            # Remember the current input for the next call, then advance the
            # WKV state slots that belong to this block.
            state[1][..., block] = x

            wkv, att_state = self.wkv_linear_attention(
                self.time_decay,
                self.time_first,
                key,
                value,
                tuple(slot[..., block] for slot in state[2:]),
            )

            for slot, updated in zip(state[2:], att_state):
                slot[..., block] = updated
        else:
            wkv = WKVLinearAttentionDecoder.apply(
                self.time_decay, self.time_first, key, value
            )

        gated = receptance * self.dropout(wkv)

        return self.proj_output(gated), state
| | | |
class EncoderSelfAttention(SelfAttention):
    """Time-mixing module using the encoder WKV kernel.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        dropout_rate: Dropout rate applied to the WKV output.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct an EncoderSelfAttention object."""
        super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)

        # Constructing the module loads the encoder kernel.
        load_encoder_wkv_kernel(context_size)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: The hidden states given as input, updated in place
                   (None when no state was given).

        """
        block = self.block_id
        stateful = state is not None

        # Previous-step input: taken from the state when one is given,
        # otherwise the input delayed by one frame.
        previous = state[1][..., block] if stateful else self.time_shift(x)

        mix_k = self.time_mix_key
        mix_v = self.time_mix_value
        mix_r = self.time_mix_receptance

        key = self.proj_key(x * mix_k + previous * (1 - mix_k))
        value = self.proj_value(x * mix_v + previous * (1 - mix_v))
        receptance = torch.sigmoid(
            self.proj_receptance(x * mix_r + previous * (1 - mix_r))
        )

        if stateful:
            # Remember the current input for the next call, then advance the
            # WKV state slots that belong to this block.
            state[1][..., block] = x

            wkv, att_state = self.wkv_linear_attention(
                self.time_decay,
                self.time_first,
                key,
                value,
                tuple(slot[..., block] for slot in state[2:]),
            )

            for slot, updated in zip(state[2:], att_state):
                slot[..., block] = updated
        else:
            wkv = WKVLinearAttentionEncoder.apply(
                self.time_decay, self.time_first, key, value
            )

        gated = receptance * self.dropout(wkv)

        return self.proj_output(gated), state
| | | |
| New file |
| | |
| | | """Feed-forward (channel mixing) module for RWKV block. |
| | | |
| | | Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py |
| | | |
| | | Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py. |
| | | |
| | | """ # noqa |
| | | |
| | | from typing import List, Optional, Tuple |
| | | |
| | | import torch |
| | | |
| | | |
class FeedForward(torch.nn.Module):
    """Channel-mixing (feed-forward) sub-layer of an RWKV block.

    Args:
        size: Input/Output size.
        hidden_size: Hidden size of the key projection.
        block_id: Index of the block owning this module; also selects the
            slice of the shared state that this module reads and writes.
        dropout_rate: Dropout probability applied to the hidden activation.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self, size: int, hidden_size: int, block_id: int, dropout_rate: float, num_blocks: int
    ) -> None:
        """Build the projections and initialise the time-mix coefficients."""
        super().__init__()

        self.block_id = block_id

        # Pads one zero frame at the front of the time axis and drops the
        # last frame, i.e. delays the sequence by one step.
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

        mix_shape = (1, 1, size)
        self.time_mix_key = torch.nn.Parameter(torch.empty(*mix_shape))
        self.time_mix_receptance = torch.nn.Parameter(torch.empty(*mix_shape))

        self.proj_key = torch.nn.Linear(size, hidden_size, bias=True)
        self.proj_value = torch.nn.Linear(hidden_size, size, bias=True)
        self.proj_receptance = torch.nn.Linear(size, size, bias=True)

        self.reset_parameters(size, block_id, num_blocks)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def reset_parameters(self, size: int, block_id: int, num_blocks: int) -> None:
        """Initialise both time-mix parameters.

        Channel ``i`` receives ``(i / size) ** (1 - block_id / num_blocks)``.

        Args:
            size: Block size.
            block_id: Block index.
            num_blocks: Number of blocks in the architecture.

        """
        exponent = 1.0 - (block_id / num_blocks)

        # Per-channel ramp 0, 1/size, ..., (size - 1)/size.
        channel_ramp = torch.tensor(
            [channel / size for channel in range(size)]
        ).reshape(1, 1, size)

        with torch.no_grad():
            # Two separate tensors so the parameters never share storage.
            self.time_mix_key.data = torch.pow(channel_ramp, exponent)
            self.time_mix_receptance.data = torch.pow(channel_ramp, exponent)

    def forward(
        self, x: torch.Tensor, state: Optional[List[torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute channel mixing.

        Args:
            x: FeedForward input sequences. (B, U, size)
            state: Hidden state. [5 x (B, 1, size, N)]

        Returns:
            x: FeedForward output sequences. (B, U, size)
            state: Hidden state. [5 x (B, 1, size, N)]

        """
        if state is None:
            # Full-sequence mode: the previous frame is the input delayed by one.
            previous = self.time_shift(x)
        else:
            # Incremental mode: the previous frame is cached in the state.
            previous = state[0][..., self.block_id]

        mix_key = self.time_mix_key
        mix_receptance = self.time_mix_receptance

        key_input = x * mix_key + previous * (1 - mix_key)
        receptance_input = x * mix_receptance + previous * (1 - mix_receptance)

        hidden = torch.square(torch.relu(self.proj_key(key_input)))
        value = self.proj_value(self.dropout(hidden))
        receptance = torch.sigmoid(self.proj_receptance(receptance_input))

        if state is not None:
            # Cache the current input for the next call.
            state[0][..., self.block_id] = x

        return receptance * value, state
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- coding: utf-8 -*- |
| | | |
| | | # Copyright 2019 Shigeki Karita |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | """Subsampling layer definition.""" |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from funasr.modules.embedding import PositionalEncoding |
| | | import logging |
| | | from funasr.modules.streaming_utils.utils import sequence_mask |
| | | from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len |
| | | from typing import Optional, Tuple, Union |
| | | import math |
| | | |
class TooShortUttError(Exception):
    """Error signalling that an utterance is too short to be subsampled.

    Args:
        message (str): Message for error catch
        actual_size (int): the short size that cannot pass the subsampling
        limit (int): the limit size for subsampling

    """

    def __init__(self, message, actual_size, limit):
        """Keep the offending size and the limit alongside the message."""
        super().__init__(message)
        self.limit = limit
        self.actual_size = actual_size
| | | |
| | | |
def check_short_utt(ins, size):
    """Check if the utterance is too short for subsampling.

    Args:
        ins: Subsampling module instance.
        size: Number of frames of the utterance.

    Returns:
        ``(True, limit)`` when ``ins`` is (a subclass of) one of the known
        ``Conv2dSubsampling*`` classes and ``size`` is below that class's
        minimum number of frames, otherwise ``(False, -1)``.

    """
    # The Conv2dSubsampling* classes are not imported in this module, so an
    # isinstance() check against the bare names raised NameError on every
    # call. Matching on class names along the MRO keeps the same semantics
    # (subclasses included) without requiring the classes to be in scope.
    minimum_frames = (
        ("Conv2dSubsampling2", 3),
        ("Conv2dSubsampling", 7),
        ("Conv2dSubsampling6", 11),
        ("Conv2dSubsampling8", 15),
    )
    class_names = {cls.__name__ for cls in type(ins).__mro__}

    for class_name, limit in minimum_frames:
        if class_name in class_names and size < limit:
            return True, limit
    return False, -1
| | | |
| | | |
class RWKVConvInput(torch.nn.Module):
    """Streaming ConvInput module definition.

    A stack of six 2-D convolutions, each followed by a ReLU, optionally
    followed by a linear projection of the flattened (channel, frequency)
    axes.

    Args:
        input_size: Input size.
        conv_size: Convolution sizes, unpacked as three channel counts
            ``(conv_size1, conv_size2, conv_size3)``.
        subsampling_factor: Subsampling factor.
        conv_kernel_size: Kernel size shared by every convolution.
        output_size: Block output dimension. When None, no projection is
            applied and the output dimension is the flattened conv output.
    """

    def __init__(
        self,
        input_size: int,
        conv_size: Union[int, Tuple],
        subsampling_factor: int = 4,
        conv_kernel_size: int = 3,
        output_size: Optional[int] = None,
    ) -> None:
        """Construct a ConvInput object."""
        super().__init__()
        conv_size1, conv_size2, conv_size3 = conv_size
        padding = (conv_kernel_size - 1) // 2

        if subsampling_factor == 1:
            # Time axis untouched; frequency axis halved three times.
            strides = (1, (1, 2), 1, (1, 2), 1, (1, 2))
            self.subsampling_factor = 1
            self.stride_1 = 1
        else:
            # Time axis reduced by kernel_1 * 2; frequency axis halved twice.
            kernel_1 = int(subsampling_factor / 2)
            strides = (1, (kernel_1, 2), 1, (2, 2), 1, 1)
            self.subsampling_factor = subsampling_factor
            self.stride_1 = kernel_1

        channels = (
            (1, conv_size1),
            (conv_size1, conv_size1),
            (conv_size1, conv_size2),
            (conv_size2, conv_size2),
            (conv_size2, conv_size3),
            (conv_size3, conv_size3),
        )

        layers = []
        freq_size = input_size
        for (in_channels, out_channels), stride in zip(channels, strides):
            layers.append(
                torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    conv_kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            layers.append(torch.nn.ReLU())
            # Track the frequency size with the Conv2d output-size formula so
            # the projection matches what the stack really emits. The former
            # hard-coded (input_size // 2) // 2 was wrong for the
            # subsampling_factor == 1 layout and for input sizes that are not
            # a multiple of 4.
            freq_stride = stride if isinstance(stride, int) else stride[1]
            freq_size = (freq_size + 2 * padding - conv_kernel_size) // freq_stride + 1

        # Same layer order and indices as before, so state_dict keys are unchanged.
        self.conv = torch.nn.Sequential(*layers)

        output_proj = conv_size3 * freq_size

        self.create_new_mask = self.create_new_vgg_mask

        self.min_frame_length = 7

        if output_size is not None:
            self.output = torch.nn.Linear(output_proj, output_size)
            self.output_size = output_size
        else:
            self.output = None
            self.output_size = output_proj

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode input sequences.

        Args:
            x: ConvInput input sequences. (B, T, D_feats)
            mask: Mask of input sequences, or None. (B, T). Entries equal to
                0 are counted as valid frames.
            chunk_size: When not None, the input is padded to a multiple of
                ``chunk_size * subsampling_factor`` frames and convolved
                chunk by chunk.

        Returns:
            x: ConvInput output sequences. (B, sub(T), D_out)
            mask: Mask of output sequences (B, sub(T)), or None when no mask
                was given.
        """
        b, t, f = x.size()

        if mask is not None:
            mask = self.create_new_mask(mask)
            # Longest valid (zero-valued) length in the batch after subsampling.
            olens = mask.eq(0).sum(1).max()
        else:
            # Without a mask every frame is valid: derive the subsampled
            # length by sending an all-valid mask through the same slicing.
            olens = self.create_new_mask(x.new_zeros(1, t)).size(1)

        x = x.unsqueeze(1)  # (b, 1, t, f)

        if chunk_size is not None:
            max_input_length = int(
                chunk_size
                * self.subsampling_factor
                * (math.ceil(float(t) / (chunk_size * self.subsampling_factor)))
            )
            x = torch.stack(
                [pad_to_len(inputs, max_input_length, 1) for inputs in x], dim=0
            )
            n_chunks = max_input_length // (chunk_size * self.subsampling_factor)
            x = x.view(b * n_chunks, 1, chunk_size * self.subsampling_factor, f)

        x = self.conv(x)

        _, c, _, f = x.size()
        x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
        if chunk_size is not None:
            # Drop the frames that only exist because of chunk padding.
            x = x[:, :olens, :]

        if self.output is not None:
            x = self.output(x)

        if mask is not None:
            mask = mask[:, :olens][:, : x.size(1)]

        return x, mask

    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create a new mask for VGG output sequences.

        Args:
            mask: Mask of input sequences. (B, T)

        Returns:
            mask: Mask of output sequences. (B, sub(T))
        """
        if self.subsampling_factor > 1:
            return mask[:, ::2][:, :: self.stride_1]
        return mask

    def get_size_before_subsampling(self, size: int) -> int:
        """Return the original size before subsampling for a given size.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of frames before subsampling.
        """
        return size * self.subsampling_factor