aky15
2023-11-01 4e0404e04ed890717ead276e52c927a820326ec1
funasr/modules/rwkv_attention.py
@@ -445,7 +445,7 @@
        """
        num_state, den_state, max_state = state
        time_decay = -torch.exp(time_decay)
        max_for_output = torch.maximum(max_state, (time_first + key))
        e1 = torch.exp(max_state - max_for_output)
@@ -495,7 +495,7 @@
            dropout_rate,
            num_blocks
        )
        load_decoder_wkv_kernel(context_size)
        # load_decoder_wkv_kernel(context_size)
    def forward(
        self,
@@ -577,7 +577,7 @@
            dropout_rate,
            num_blocks
        )
        load_encoder_wkv_kernel(context_size)
        # load_encoder_wkv_kernel(context_size)
    def forward(
        self,