zhifu gao
2023-11-22 b57b98364ff60ae0119b2e8d92471316bb4e504f
funasr/modules/rwkv_attention.py
@@ -445,7 +445,7 @@
        """
        num_state, den_state, max_state = state
        time_decay = -torch.exp(time_decay)
        max_for_output = torch.maximum(max_state, (time_first + key))
        e1 = torch.exp(max_state - max_for_output)
@@ -495,7 +495,7 @@
            dropout_rate,
            num_blocks
        )
        load_decoder_wkv_kernel(context_size)
        # load_decoder_wkv_kernel(context_size)
    def forward(
        self,
@@ -577,7 +577,7 @@
            dropout_rate,
            num_blocks
        )
        load_encoder_wkv_kernel(context_size)
        # load_encoder_wkv_kernel(context_size)
    def forward(
        self,