From 26a2a232a94c4a729733d83e8175a16e3f8db481 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 23 Oct 2023 16:42:43 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR; add RWKV encoder

---
 funasr/models/encoder/rwkv_encoder.py |  155 +++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 155 insertions(+), 0 deletions(-)

diff --git a/funasr/models/encoder/rwkv_encoder.py b/funasr/models/encoder/rwkv_encoder.py
new file mode 100644
index 0000000..8a33520
--- /dev/null
+++ b/funasr/models/encoder/rwkv_encoder.py
@@ -0,0 +1,155 @@
+"""RWKV encoder definition for Transducer models."""
+
+from typing import List, Optional, Tuple
+
+import torch
+
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.modules.rwkv import RWKV
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.rwkv_subsampling import RWKVConvInput
+from funasr.modules.nets_utils import make_source_mask
+
+
+class RWKVEncoder(AbsEncoder):
+    """RWKV encoder module.
+
+    Based on https://arxiv.org/pdf/2305.13048.pdf.
+
+    Args:
+        input_size: Input feature size.
+        output_size: Input/Output size of the RWKV blocks.
+        context_size: Context size for WKV computation.
+        linear_size: FeedForward hidden size. Defaults to (output_size * 4).
+        attention_size: SelfAttention hidden size. Defaults to output_size.
+        num_blocks: Number of RWKV blocks.
+        att_dropout_rate: Dropout rate for the attention module.
+        ffn_dropout_rate: Dropout rate for the feed-forward module.
+        dropout_rate: Dropout rate applied to the RWKV block outputs.
+        subsampling_factor: Subsampling factor of the convolutional frontend.
+        time_reduction_factor: Additional time reduction applied to the encoder output.
+        kernel: Kernel size of the subsampling convolutions.
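+
+    Example:
+        >>> # Minimal usage sketch; output shapes are indicative and depend on
+        >>> # the conv frontend's exact padding.
+        >>> encoder = RWKVEncoder(input_size=80, output_size=512, num_blocks=4)
+        >>> feats = torch.randn(2, 200, 80)
+        >>> feats_len = torch.tensor([200, 160])
+        >>> out, olens, _ = encoder(feats, feats_len)  # out: ~(2, 200 // 4, 512)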
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 512,
+        context_size: int = 1024,
+        linear_size: Optional[int] = None,
+        attention_size: Optional[int] = None,
+        num_blocks: int = 4,
+        att_dropout_rate: float = 0.0,
+        ffn_dropout_rate: float = 0.0,
+        dropout_rate: float = 0.0,
+        subsampling_factor: int = 4,
+        time_reduction_factor: int = 1,
+        kernel: int = 3,
+    ) -> None:
+        """Construct a RWKVEncoder object."""
+        super().__init__()
+
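+        # Convolutional frontend: subsamples the input features in time and
+        # projects them to the RWKV model width.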
+        self.embed = RWKVConvInput(
+            input_size,
+            [output_size // 4, output_size // 2, output_size],
+            subsampling_factor,
+            conv_kernel_size=kernel,
+            output_size=output_size,
+        )
+
+        self.subsampling_factor = subsampling_factor
+
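+        # Default widths follow the RWKV paper: the feed-forward hidden size is
+        # 4x the model width and the attention size matches the model width.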
+        linear_size = output_size * 4 if linear_size is None else linear_size
+        attention_size = output_size if attention_size is None else attention_size
+
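+        # Stack of RWKV blocks; each block receives its index and the total
+        # depth, used for layer-dependent initialization.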
+        self.rwkv_blocks = torch.nn.ModuleList(
+            [
+                RWKV(
+                    output_size,
+                    linear_size,
+                    attention_size,
+                    context_size,
+                    block_id,
+                    num_blocks,
+                    att_dropout_rate=att_dropout_rate,
+                    ffn_dropout_rate=ffn_dropout_rate,
+                    dropout_rate=dropout_rate,
+                )
+                for block_id in range(num_blocks)
+            ]
+        )
+
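+        # LayerNorm after the frontend embedding and a final LayerNorm after
+        # the last RWKV block.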
+        self.embed_norm = LayerNorm(output_size)
+        self.final_norm = LayerNorm(output_size)
+
+        self._output_size = output_size
+        self.context_size = context_size
+
+        self.num_blocks = num_blocks
+        self.time_reduction_factor = time_reduction_factor
+
+    def output_size(self) -> int:
+        """Return the encoder output size."""
+        return self._output_size
+
+    def forward(
+        self, x: torch.Tensor, x_len: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
+        """Encode source feature sequences.
+
+        Args:
+            x: Encoder input sequences. (B, T, D_feats)
+            x_len: Encoder input sequence lengths. (B,)
+
+        Returns:
+            x: Encoder output sequences. (B, T_sub, D_enc)
+            olens: Encoder output sequence lengths. (B,)
+            None: Placeholder for interface compatibility.
+
+        """
+        _, length, _ = x.size()
+
+        assert (
+            length <= self.context_size * self.subsampling_factor
+        ), "Context size is too short for the current length: %d versus %d" % (
+            length,
+            self.context_size * self.subsampling_factor,
+        )
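+        # Build the padding mask from the input lengths, then subsample with
+        # the convolutional frontend (which also subsamples the mask).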
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        x = self.embed_norm(x)
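+        # Output lengths: count positions not flagged by the subsampled mask;
+        # this assumes the mask is non-zero at padded positions.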
+        olens = mask.eq(0).sum(1)
+
+        for block in self.rwkv_blocks:
+            x, _ = block(x)
+        # For streaming inference, process the sequence frame by frame instead:
+        # x = self.rwkv_infer(x)
+
+        x = self.final_norm(x)
+
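+        # Optional extra frame-rate reduction on top of the conv subsampling.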
+        if self.time_reduction_factor > 1:
+            x = x[:, :: self.time_reduction_factor, :]
+            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
+
+        return x, olens, None
+
+    def rwkv_infer(self, xs_pad: torch.Tensor) -> torch.Tensor:
+        """Encode sequences frame by frame for streaming inference.
+
+        Args:
+            xs_pad: Subsampled input sequences. (B, T, D_enc)
+
+        Returns:
+            xs_out: Encoder output sequences. (B, T, D_enc)
+
+        """
+        batch_size = xs_pad.shape[0]
+
+        # Each block keeps five recurrent states: the FFN and attention token
+        # shifts, plus the numerator, denominator and max-tracking states of
+        # the WKV kernel. Sizing them all to output_size assumes the default
+        # attention_size == output_size.
+        hidden_sizes = [self._output_size for _ in range(5)]
+
+        state: List[torch.Tensor] = [
+            torch.zeros(
+                (batch_size, 1, hidden_sizes[i], self.num_blocks),
+                dtype=torch.float32,
+                device=xs_pad.device,
+            )
+            for i in range(5)
+        ]
+
+        # The max-tracking state starts at a large negative value for the
+        # numerically stable WKV recursion, as in reference RWKV implementations.
+        state[4] -= 1e30
+
+        xs_out = []
+        for t in range(xs_pad.shape[1]):
+            x_t = xs_pad[:, t, :]
+            for block in self.rwkv_blocks:
+                x_t, state = block(x_t, state=state)
+            xs_out.append(x_t)
+
+        return torch.stack(xs_out, dim=1)

--
Gitblit v1.9.1