| | |
| | | state: Decoder hidden state. [5 x (B, 1, size, N)] |
| | | |
| | | """ |
| | | shifted_x = ( |
| | | self.time_shift(x) if state is None else state[0][..., self.block_id] |
| | | ) |
| | | shifted_x = self.time_shift(x) if state is None else state[0][..., self.block_id] |
| | | |
| | | key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key) |
| | | receptance = x * self.time_mix_receptance + shifted_x * ( |
| | | 1 - self.time_mix_receptance |
| | | ) |
| | | receptance = x * self.time_mix_receptance + shifted_x * (1 - self.time_mix_receptance) |
| | | |
| | | key = torch.square(torch.relu(self.proj_key(key))) |
| | | value = self.proj_value(self.dropout(key)) |