| | |
| | | """RWKV encoder definition for Transducer models.""" |
| | | |
| | | import math |
| | | from typing import Dict, List, Optional, Tuple |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import torch |
| | | from typing import Dict, List, Optional, Tuple |
| | | |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.register import tables |
| | | from funasr.models.rwkv_bat.rwkv import RWKV |
| | | from funasr.models.transformer.layer_norm import LayerNorm |
| | | from funasr.models.rwkv_bat.rwkv_subsampling import RWKVConvInput |
| | | from funasr.models.transformer.utils.nets_utils import make_source_mask |
| | | from funasr.models.rwkv_bat.rwkv_subsampling import RWKVConvInput |
| | | |
| | | class RWKVEncoder(AbsEncoder): |
| | | |
| | | @tables.register("encoder_classes", "RWKVEncoder") |
| | | class RWKVEncoder(torch.nn.Module): |
| | | """RWKV encoder module. |
| | | |
| | | Based on https://arxiv.org/pdf/2305.13048.pdf. |
| | |
| | | subsampling_factor: int =4, |
| | | time_reduction_factor: int = 1, |
| | | kernel: int = 3, |
| | | **kwargs, |
| | | ) -> None: |
| | | """Construct a RWKVEncoder object.""" |
| | | super().__init__() |
| | |
| | | x = self.embed_norm(x) |
| | | olens = mask.eq(0).sum(1) |
| | | |
| | | # for training |
| | | # for block in self.rwkv_blocks: |
| | | # x, _ = block(x) |
| | | |
| | | # for streaming inference |
| | | x = self.rwkv_infer(x) |
| | | if self.training: |
| | | for block in self.rwkv_blocks: |
| | | x, _ = block(x) |
| | | else: |
| | | x = self.rwkv_infer(x) |
| | | |
| | | x = self.final_norm(x) |
| | | |
| | | if self.time_reduction_factor > 1: |