| | |
| | | x = self.embed_norm(x) |
| | | olens = mask.eq(0).sum(1) |
| | | |
| | | for block in self.rwkv_blocks: |
| | | x, _ = block(x) |
| | | # for streaming inference |
| | | # xs_pad = self.rwkv_infer(xs_pad) |
| | | # for training |
| | | # for block in self.rwkv_blocks: |
| | | # x, _ = block(x) |
| | | |
| | | # for streaming inference |
| | | x = self.rwkv_infer(x) |
| | | x = self.final_norm(x) |
| | | |
| | | if self.time_reduction_factor > 1: |
| | |
| | | |
| | | state = [ |
| | | torch.zeros( |
| | | (batch_size, 1, hidden_sizes[i], self.num_rwkv_blocks), |
| | | (batch_size, 1, hidden_sizes[i], self.num_blocks), |
| | | dtype=torch.float32, |
| | | device=self.device, |
| | | device=xs_pad.device, |
| | | ) |
| | | for i in range(5) |
| | | ] |
| | |
| | | for idx, block in enumerate(self.rwkv_blocks): |
| | | x_t, state = block(x_t, state=state) |
| | | xs_out.append(x_t) |
| | | xs_out = torch.stack(xs_out, dim=1) |
| | | xs_out = torch.cat(xs_out, dim=1) |
| | | return xs_out |