| | |
| | | """Feed-forward (channel mixing) module for RWKV block. |
| | | |
| | | Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py |
| | | |
| | | Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py. |
| | | |
| | | """ # noqa |
| | | |
| | | from typing import List, Optional, Tuple |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import torch |
| | | from typing import List, Optional, Tuple |
| | | |
| | | |
| | | class FeedForward(torch.nn.Module): |
| | |
| | | state: Decoder hidden state. [5 x (B, 1, size, N)] |
| | | |
| | | """ |
| | | shifted_x = ( |
| | | self.time_shift(x) if state is None else state[0][..., self.block_id] |
| | | ) |
| | | shifted_x = self.time_shift(x) if state is None else state[0][..., self.block_id] |
| | | |
| | | key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key) |
| | | receptance = x * self.time_mix_receptance + shifted_x * ( |
| | | 1 - self.time_mix_receptance |
| | | ) |
| | | receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance) |
| | | |
| | | key = torch.square(torch.relu(self.proj_key(key))) |
| | | value = self.proj_value(self.dropout(key)) |