| | |
| | | att_dropout_rate: float = 0.0, |
| | | ffn_dropout_rate: float = 0.0, |
| | | dropout_rate: float = 0.0, |
| | | subsampling_factor: int =4, |
| | | subsampling_factor: int = 4, |
| | | time_reduction_factor: int = 1, |
| | | kernel: int = 3, |
| | | **kwargs, |
| | |
| | | |
| | | self.embed = RWKVConvInput( |
| | | input_size, |
| | | [output_size//4, output_size//2, output_size], |
| | | [output_size // 4, output_size // 2, output_size], |
| | | subsampling_factor, |
| | | conv_kernel_size=kernel, |
| | | output_size=output_size, |
| | |
| | | |
| | | linear_size = output_size * 4 if linear_size is None else linear_size |
| | | attention_size = output_size if attention_size is None else attention_size |
| | | |
| | | |
| | | self.rwkv_blocks = torch.nn.ModuleList( |
| | | [ |
| | | RWKV( |
| | |
| | | x, _ = block(x) |
| | | else: |
| | | x = self.rwkv_infer(x) |
| | | |
| | | |
| | | x = self.final_norm(x) |
| | | |
| | | if self.time_reduction_factor > 1: |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 |
| | | x = x[:, :: self.time_reduction_factor, :] |
| | | olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1 |
| | | |
| | | return x, olens, None |
| | | |
| | |
| | | |
| | | batch_size = xs_pad.shape[0] |
| | | |
| | | hidden_sizes = [ |
| | | self._output_size for i in range(5) |
| | | ] |
| | | hidden_sizes = [self._output_size for i in range(5)] |
| | | |
| | | state = [ |
| | | torch.zeros( |
| | |
| | | |
| | | xs_out = [] |
| | | for t in range(xs_pad.shape[1]): |
| | | x_t = xs_pad[:,t,:] |
| | | x_t = xs_pad[:, t, :] |
| | | for idx, block in enumerate(self.rwkv_blocks): |
| | | x_t, state = block(x_t, state=state) |
| | | xs_out.append(x_t) |