"""RWKV encoder definition for Transducer models."""

import math
from typing import Dict, List, Optional, Tuple

import torch

from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.rwkv import RWKV
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.rwkv_subsampling import RWKVConvInput
from funasr.modules.nets_utils import make_source_mask


class RWKVEncoder(AbsEncoder):
    """RWKV encoder module.

    Based on https://arxiv.org/pdf/2305.13048.pdf.

    Args:
        input_size: Input feature dimension.
        output_size: Input/Output size of the RWKV blocks.
        context_size: Context size for WKV computation.
        linear_size: FeedForward hidden size (defaults to ``4 * output_size``).
        attention_size: SelfAttention hidden size (defaults to ``output_size``).
        num_blocks: Number of RWKV blocks.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate applied to each RWKV block output.
        subsampling_factor: Subsampling factor of the convolutional front-end.
        time_reduction_factor: Additional frame-rate reduction applied to the
            encoder output.
        kernel: Convolution kernel size of the front-end.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 512,
        context_size: int = 1024,
        linear_size: Optional[int] = None,
        attention_size: Optional[int] = None,
        num_blocks: int = 4,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
        subsampling_factor: int = 4,
        time_reduction_factor: int = 1,
        kernel: int = 3,
    ) -> None:
        """Construct a RWKVEncoder object."""
        super().__init__()

        self.embed = RWKVConvInput(
            input_size,
            [output_size // 4, output_size // 2, output_size],
            subsampling_factor,
            conv_kernel_size=kernel,
            output_size=output_size,
        )

        self.subsampling_factor = subsampling_factor

        # Fall back to the conventional RWKV widths when not given explicitly.
        linear_size = output_size * 4 if linear_size is None else linear_size
        attention_size = output_size if attention_size is None else attention_size

        self.rwkv_blocks = torch.nn.ModuleList(
            [
                RWKV(
                    output_size,
                    linear_size,
                    attention_size,
                    context_size,
                    block_id,
                    num_blocks,
                    att_dropout_rate=att_dropout_rate,
                    ffn_dropout_rate=ffn_dropout_rate,
                    dropout_rate=dropout_rate,
                )
                for block_id in range(num_blocks)
            ]
        )

        self.embed_norm = LayerNorm(output_size)
        self.final_norm = LayerNorm(output_size)

        self._output_size = output_size
        self.context_size = context_size

        self.num_blocks = num_blocks
        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        """Return the encoder output feature dimension."""
        return self._output_size

    def forward(
        self, x: torch.Tensor, x_len: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """Encode source feature sequences.

        Args:
            x: Encoder input sequences. (B, T, F)
            x_len: Lengths of the input sequences. (B,)

        Returns:
            x: Encoder output sequences. (B, U, D)
            olens: Lengths of the output sequences. (B,)
            None: Placeholder kept for encoder-interface compatibility.
        """
        _, length, _ = x.size()

        # WKV computation only covers context_size frames after subsampling,
        # so longer inputs cannot be encoded correctly.
        assert (
            length <= self.context_size * self.subsampling_factor
        ), "Context size is too short for current length: %d versus %d" % (
            length,
            self.context_size * self.subsampling_factor,
        )

        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        x = self.embed_norm(x)
        # NOTE(review): assumes mask marks padding positions with non-zero
        # entries, so zero-count equals the valid length — confirm against
        # make_source_mask / RWKVConvInput.
        olens = mask.eq(0).sum(1)

        for block in self.rwkv_blocks:
            x, _ = block(x)
        # for streaming inference
        # xs_pad = self.rwkv_infer(xs_pad)

        x = self.final_norm(x)

        if self.time_reduction_factor > 1:
            # Subsample output frames and recompute the valid lengths.
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens, None

    def rwkv_infer(self, xs_pad: torch.Tensor) -> torch.Tensor:
        """Encode a sequence frame-by-frame with recurrent RWKV states.

        Args:
            xs_pad: Input sequences. (B, T, D)

        Returns:
            xs_out: Output sequences. (B, T, D)
        """
        batch_size = xs_pad.shape[0]

        hidden_sizes = [self._output_size for _ in range(5)]

        # One state tensor per RWKV internal state slot, carrying one slice
        # per block along the last dimension.
        # Fix: was self.num_rwkv_blocks (never defined; __init__ stores
        # self.num_blocks) and device=self.device (never set on this module).
        state = [
            torch.zeros(
                (batch_size, 1, hidden_sizes[i], self.num_blocks),
                dtype=torch.float32,
                device=xs_pad.device,
            )
            for i in range(5)
        ]

        # NOTE(review): non-zero initialization of the 5th state slot —
        # presumably the WKV max-state offset; confirm against
        # funasr.modules.rwkv.
        state[4] -= 1e-30

        xs_out = []
        for t in range(xs_pad.shape[1]):
            x_t = xs_pad[:, t, :]
            for block in self.rwkv_blocks:
                x_t, state = block(x_t, state=state)
            xs_out.append(x_t)

        return torch.stack(xs_out, dim=1)