From 4e0404e04ed890717ead276e52c927a820326ec1 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 01 十一月 2023 16:47:13 +0800
Subject: [PATCH] fix rwkv infer bugs

---
 funasr/bin/asr_infer.py                 |    1 
 funasr/modules/cuda_encoder/wkv_op.cpp  |   37 +++++
 funasr/modules/cuda_encoder/wkv_cuda.cu |  135 +++++++++++++++++++
 funasr/modules/rwkv_attention.py        |    6 
 funasr/modules/cuda_decoder/wkv_cuda.cu |  135 +++++++++++++++++++
 funasr/models/encoder/rwkv_encoder.py   |   15 +-
 funasr/modules/cuda_decoder/wkv_op.cpp  |   37 +++++
 7 files changed, 355 insertions(+), 11 deletions(-)

diff --git a/funasr/bin/asr_infer.py b/funasr/bin/asr_infer.py
index 4648fb3..7015eb8 100644
--- a/funasr/bin/asr_infer.py
+++ b/funasr/bin/asr_infer.py
@@ -1605,7 +1605,6 @@
         feats_lengths = to_device(feats_lengths, device=self.device)
 
         enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths)
-
         nbest_hyps = self.beam_search(enc_out[0])
 
         return nbest_hyps
diff --git a/funasr/models/encoder/rwkv_encoder.py b/funasr/models/encoder/rwkv_encoder.py
index 8a33520..40151bf 100644
--- a/funasr/models/encoder/rwkv_encoder.py
+++ b/funasr/models/encoder/rwkv_encoder.py
@@ -113,11 +113,12 @@
         x = self.embed_norm(x)
         olens = mask.eq(0).sum(1)
 
-        for block in self.rwkv_blocks:
-            x, _ = block(x)
-        # for streaming inference
-        # xs_pad = self.rwkv_infer(xs_pad)
+        # for training
+        # for block in self.rwkv_blocks:
+        #     x, _ = block(x)
 
+        # for streaming inference
+        x = self.rwkv_infer(x)
         x = self.final_norm(x)
 
         if self.time_reduction_factor > 1:
@@ -136,9 +137,9 @@
 
         state = [
             torch.zeros(
-                (batch_size, 1, hidden_sizes[i], self.num_rwkv_blocks),
+                (batch_size, 1, hidden_sizes[i], self.num_blocks),
                 dtype=torch.float32,
-                device=self.device,
+                device=xs_pad.device,
             )
             for i in range(5)
         ]
@@ -151,5 +152,5 @@
             for idx, block in enumerate(self.rwkv_blocks):
                 x_t, state = block(x_t, state=state)
             xs_out.append(x_t)
-        xs_out = torch.stack(xs_out, dim=1)
+        xs_out = torch.cat(xs_out, dim=1)
         return xs_out
diff --git a/funasr/modules/cuda_decoder/wkv_cuda.cu b/funasr/modules/cuda_decoder/wkv_cuda.cu
new file mode 100644
index 0000000..1dcbc71
--- /dev/null
+++ b/funasr/modules/cuda_decoder/wkv_cuda.cu
@@ -0,0 +1,135 @@
+// Copied from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_cuda.cu
+
+#include <stdio.h>
+#include <assert.h>
+
+#define MIN_VALUE (-1e38)
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C,
+                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                               F *__restrict__ const _y) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    F *__restrict__ const y = _y + _offset;
+
+    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+    F aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C,
+                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
+                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    const F *__restrict__ const y = _y + _offset;
+    const F *__restrict__ const gy = _gy + _offset;
+    F *__restrict__ const gk = _gk + _offset;
+    F *__restrict__ const gv = _gv + _offset;
+
+    F q[Tmax], r[Tmax];
+
+    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        const F qq = gy[ii] / (e1 * bb + e2);
+        gw += (ga - gb * yy) * e1 * qq;
+        gu += (vv - yy) * e2 * qq;
+        q[i] = qq;
+        r[i] = ww - p;
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        ga = e1 * (aa + ga);
+        gb = e1 * (bb + gb);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+    const int _offsetBC = _b * C + _c;
+    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
+    _gu[_offsetBC] = gu;
+
+    aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = T - 1; i >= 0; i--) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+        const F qq = q[i];
+        const F rr = r[i];
+
+        F e1 = qq * exp(rr);
+        F e2 = exp(kk + pp);
+        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
+        gv[ii] = e1 + e2 * aa;
+
+        const F ww = w + pp;
+        const F www = rr - u - kk;
+        const F p = max(ww, www);
+        e1 = exp(ww - p);
+        e2 = qq * exp(www - p);
+        aa = e1 * aa + e2;
+        bb = e1 * bb - e2 * yy;
+        pp = p;
+    }
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+}
diff --git a/funasr/modules/cuda_decoder/wkv_op.cpp b/funasr/modules/cuda_decoder/wkv_op.cpp
new file mode 100644
index 0000000..1024219
--- /dev/null
+++ b/funasr/modules/cuda_decoder/wkv_op.cpp
@@ -0,0 +1,37 @@
+/*
+ *    Bsed on https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_op.cpp
+   Function signatures were modified based on https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/rwkv/wkv_op.cpp
+
+ */
+
+#include <torch/extension.h>
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
+
+void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+  const int B = k.size(0);
+  const int T = k.size(1);
+  const int C = k.size(2);
+
+  cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
+}
+
+void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+  const int B = k.size(0);
+  const int T = k.size(1);
+  const int C = k.size(2);
+
+  cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv forward");
+    m.def("backward", &backward, "wkv backward");
+}
+
+TORCH_LIBRARY(wkv_decoder, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
diff --git a/funasr/modules/cuda_encoder/wkv_cuda.cu b/funasr/modules/cuda_encoder/wkv_cuda.cu
new file mode 100644
index 0000000..1dcbc71
--- /dev/null
+++ b/funasr/modules/cuda_encoder/wkv_cuda.cu
@@ -0,0 +1,135 @@
+// Copied from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_cuda.cu
+
+#include <stdio.h>
+#include <assert.h>
+
+#define MIN_VALUE (-1e38)
+
+template <typename F>
+__global__ void kernel_forward(const int B, const int T, const int C,
+                               const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                               F *__restrict__ const _y) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    F *__restrict__ const y = _y + _offset;
+
+    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
+    F aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+}
+
+template <typename F>
+__global__ void kernel_backward(const int B, const int T, const int C,
+                                const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
+                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
+                                F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
+    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    const int _b = idx / C;
+    const int _c = idx % C;
+    const int _offset = _b * T * C + _c;
+
+    F u = _u[_c];
+    F w = _w[_c];
+    const F *__restrict__ const k = _k + _offset;
+    const F *__restrict__ const v = _v + _offset;
+    const F *__restrict__ const y = _y + _offset;
+    const F *__restrict__ const gy = _gy + _offset;
+    F *__restrict__ const gk = _gk + _offset;
+    F *__restrict__ const gv = _gv + _offset;
+
+    F q[Tmax], r[Tmax];
+
+    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
+    for (int i = 0; i < T; i++) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+
+        F ww = u + kk;
+        F p = max(pp, ww);
+        F e1 = exp(pp - p);
+        F e2 = exp(ww - p);
+        const F qq = gy[ii] / (e1 * bb + e2);
+        gw += (ga - gb * yy) * e1 * qq;
+        gu += (vv - yy) * e2 * qq;
+        q[i] = qq;
+        r[i] = ww - p;
+
+        ww = w + pp;
+        p = max(ww, kk);
+        e1 = exp(ww - p);
+        e2 = exp(kk - p);
+        ga = e1 * (aa + ga);
+        gb = e1 * (bb + gb);
+        aa = e1 * aa + e2 * vv;
+        bb = e1 * bb + e2;
+        pp = p;
+    }
+    const int _offsetBC = _b * C + _c;
+    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
+    _gu[_offsetBC] = gu;
+
+    aa = 0, bb = 0, pp = MIN_VALUE;
+    for (int i = T - 1; i >= 0; i--) {
+        const int ii = i * C;
+        const F kk = k[ii];
+        const F vv = v[ii];
+        const F yy = y[ii];
+        const F qq = q[i];
+        const F rr = r[i];
+
+        F e1 = qq * exp(rr);
+        F e2 = exp(kk + pp);
+        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
+        gv[ii] = e1 + e2 * aa;
+
+        const F ww = w + pp;
+        const F www = rr - u - kk;
+        const F p = max(ww, www);
+        e1 = exp(ww - p);
+        e2 = qq * exp(www - p);
+        aa = e1 * aa + e2;
+        bb = e1 * bb - e2 * yy;
+        pp = p;
+    }
+}
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
+}
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
+    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
+    assert(B * C % threadsPerBlock.x == 0);
+    dim3 numBlocks(B * C / threadsPerBlock.x);
+    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
+}
diff --git a/funasr/modules/cuda_encoder/wkv_op.cpp b/funasr/modules/cuda_encoder/wkv_op.cpp
new file mode 100644
index 0000000..16f3646
--- /dev/null
+++ b/funasr/modules/cuda_encoder/wkv_op.cpp
@@ -0,0 +1,37 @@
+/*
+ *    Bsed on https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_op.cpp
+   Function signatures were modified based on https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/rwkv/wkv_op.cpp
+
+ */
+
+#include <torch/extension.h>
+
+void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
+
+void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv);
+
+void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
+  const int B = k.size(0);
+  const int T = k.size(1);
+  const int C = k.size(2);
+
+  cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
+}
+
+void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
+  const int B = k.size(0);
+  const int T = k.size(1);
+  const int C = k.size(2);
+
+  cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+    m.def("forward", &forward, "wkv forward");
+    m.def("backward", &backward, "wkv backward");
+}
+
+TORCH_LIBRARY(wkv_encoder, m) {
+    m.def("forward", forward);
+    m.def("backward", backward);
+}
diff --git a/funasr/modules/rwkv_attention.py b/funasr/modules/rwkv_attention.py
index f0c7da3..5384fb9 100644
--- a/funasr/modules/rwkv_attention.py
+++ b/funasr/modules/rwkv_attention.py
@@ -445,7 +445,7 @@
 
         """
         num_state, den_state, max_state = state
-
+        time_decay = -torch.exp(time_decay)
         max_for_output = torch.maximum(max_state, (time_first + key))
 
         e1 = torch.exp(max_state - max_for_output)
@@ -495,7 +495,7 @@
             dropout_rate,
             num_blocks
         )
-        load_decoder_wkv_kernel(context_size)
+        # load_decoder_wkv_kernel(context_size)
 
     def forward(
         self,
@@ -577,7 +577,7 @@
             dropout_rate,
             num_blocks
         )
-        load_encoder_wkv_kernel(context_size)
+        # load_encoder_wkv_kernel(context_size)
 
     def forward(
         self,

--
Gitblit v1.9.1