funasr/modules/rwkv_attention.py
@@ -445,7 +445,7 @@ """ num_state, den_state, max_state = state time_decay = -torch.exp(time_decay) max_for_output = torch.maximum(max_state, (time_first + key)) e1 = torch.exp(max_state - max_for_output) @@ -495,7 +495,7 @@ dropout_rate, num_blocks ) load_decoder_wkv_kernel(context_size) # load_decoder_wkv_kernel(context_size) def forward( self, @@ -577,7 +577,7 @@ dropout_rate, num_blocks ) load_encoder_wkv_kernel(context_size) # load_encoder_wkv_kernel(context_size) def forward( self,