| | |
| | | feats_lengths = to_device(feats_lengths, device=self.device) |
| | | |
| | | enc_out, _, _ = self.asr_model.encoder(feats, feats_lengths) |
| | | |
| | | nbest_hyps = self.beam_search(enc_out[0]) |
| | | |
| | | return nbest_hyps |
| | |
| | | x = self.embed_norm(x) |
| | | olens = mask.eq(0).sum(1) |
| | | |
| | | for block in self.rwkv_blocks: |
| | | x, _ = block(x) |
| | | # for streaming inference |
| | | # xs_pad = self.rwkv_infer(xs_pad) |
| | | # for training |
| | | # for block in self.rwkv_blocks: |
| | | # x, _ = block(x) |
| | | |
| | | # for streaming inference |
| | | x = self.rwkv_infer(x) |
| | | x = self.final_norm(x) |
| | | |
| | | if self.time_reduction_factor > 1: |
| | |
| | | |
| | | state = [ |
| | | torch.zeros( |
| | | (batch_size, 1, hidden_sizes[i], self.num_rwkv_blocks), |
| | | (batch_size, 1, hidden_sizes[i], self.num_blocks), |
| | | dtype=torch.float32, |
| | | device=self.device, |
| | | device=xs_pad.device, |
| | | ) |
| | | for i in range(5) |
| | | ] |
| | |
| | | for idx, block in enumerate(self.rwkv_blocks): |
| | | x_t, state = block(x_t, state=state) |
| | | xs_out.append(x_t) |
| | | xs_out = torch.stack(xs_out, dim=1) |
| | | xs_out = torch.cat(xs_out, dim=1) |
| | | return xs_out |
| New file |
| | |
| | | // Copied from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_cuda.cu |
| | | |
#include <stdio.h>
#include <assert.h>
#include <stdexcept>
#include <string>

// Large negative seed for the running log-scale `pp`: exp(MIN_VALUE - p) is 0
// for any realistic p, so the empty initial state contributes nothing.
#define MIN_VALUE (-1e38)

// kernel_backward keeps two per-timestep scratch arrays of length Tmax, so Tmax
// must be a compile-time constant (presumably passed as -DTmax=<context size>
// by the Python loader -- confirm against the build flags).
#ifndef Tmax
#error "wkv_cuda.cu: Tmax must be defined at compile time (e.g. -DTmax=1024)"
#endif

// Converts a failed kernel launch into a C++ exception so the error reaches the
// caller immediately instead of surfacing at a later, unrelated CUDA call.
static void check_launch(const char *kernel_name) {
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(std::string(kernel_name) + " launch failed: " + cudaGetErrorString(err));
    }
}

// WKV forward recurrence.
//
// Layout: k, v and y are dense (B, T, C) buffers -- element (b, t, c) lives at
// b * T * C + t * C + c; w and u hold one value per channel.
// Launch: 1-D grid, one thread per (batch, channel) pair. Threads with
// idx >= B * C return immediately, so the grid may be rounded up.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u,
                               const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * C) {
        return;  // padding thread of the rounded-up grid
    }
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    F aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;  // consecutive timesteps of one channel are C apart
        const F kk = k[ii];
        const F vv = v[ii];

        // Output at step i: combine the accumulated state with the current
        // token, whose key is boosted by u; p is the shared max exponent.
        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        // State update: scale the previous state by exp(w) and add the current token.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}

// Gradient of kernel_forward.
//
// Same layout and launch shape as kernel_forward. Writes gk and gv as (B, T, C)
// buffers and gw, gu as one value per (batch, channel) at b * C + c (presumably
// reduced over the batch by the caller -- confirm). Requires T <= Tmax, which
// cuda_backward enforces before launching.
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u,
                                const F *__restrict__ const _k, const F *__restrict__ const _v,
                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu,
                                F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * C) {
        return;  // padding thread of the rounded-up grid
    }
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const y = _y + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    // Per-timestep values saved by the first sweep and consumed by the reverse sweep.
    F q[Tmax], r[Tmax];

    // First sweep (t = 0 .. T-1): recompute the forward state, accumulate gw and gu.
    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        const F qq = gy[ii] / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;      // output gradient divided by the step's denominator
        r[i] = ww - p;  // current token's exponent (u + k) relative to the step max

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = gu;

    // Reverse sweep (t = T-1 .. 0): propagate gradients to k and v.
    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];
        const F qq = q[i];
        const F rr = r[i];

        F e1 = qq * exp(rr);
        F e2 = exp(kk + pp);
        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
        gv[ii] = e1 + e2 * aa;

        const F ww = w + pp;
        const F www = rr - u - kk;
        const F p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}

// Host launcher for kernel_forward (float32, default stream).
// Blocks hold min(C, 32) threads and the grid is rounded up, so B * C no longer
// has to be a multiple of the block size; the kernel skips surplus threads.
// Empty inputs (B * C == 0) are a no-op. Throws std::runtime_error if the launch fails.
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    const int total = B * C;
    if (total <= 0) {
        return;  // nothing to compute; a zero-sized launch is an invalid configuration
    }
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    dim3 numBlocks((total + threadsPerBlock.x - 1) / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
    check_launch("kernel_forward");
}

// Host launcher for kernel_backward (float32, default stream).
// Same launch shape as cuda_forward. Throws std::runtime_error when T exceeds
// the compile-time Tmax (the kernel's scratch arrays would overflow) or when
// the launch fails.
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
    if (T > Tmax) {
        throw std::runtime_error("wkv backward: sequence length " + std::to_string(T) +
                                 " exceeds compile-time Tmax " + std::to_string(Tmax));
    }
    const int total = B * C;
    if (total <= 0) {
        return;  // nothing to compute; a zero-sized launch is an invalid configuration
    }
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    dim3 numBlocks((total + threadsPerBlock.x - 1) / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
    check_launch("kernel_backward");
}
| New file |
| | |
| | | /* |
| | | * Based on https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_op.cpp |
| | | Function signatures were modified based on https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/rwkv/wkv_op.cpp |
| | | |
| | | */ |
| | | |
| | | #include <torch/extension.h> |
| | | |
| | | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); |
| | | |
| | | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); |
| | | |
| | | void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { |
| | | const int B = k.size(0); |
| | | const int T = k.size(1); |
| | | const int C = k.size(2); |
| | | |
| | | cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); |
| | | } |
| | | |
| | | void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { |
| | | const int B = k.size(0); |
| | | const int T = k.size(1); |
| | | const int C = k.size(2); |
| | | |
| | | cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); |
| | | } |
| | | |
| | | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| | | m.def("forward", &forward, "wkv forward"); |
| | | m.def("backward", &backward, "wkv backward"); |
| | | } |
| | | |
| | | TORCH_LIBRARY(wkv_decoder, m) { |
| | | m.def("forward", forward); |
| | | m.def("backward", backward); |
| | | } |
| New file |
| | |
| | | // Copied from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_cuda.cu |
| | | |
#include <stdio.h>
#include <assert.h>
#include <stdexcept>
#include <string>

// Large negative seed for the running log-scale `pp`: exp(MIN_VALUE - p) is 0
// for any realistic p, so the empty initial state contributes nothing.
#define MIN_VALUE (-1e38)

// kernel_backward keeps two per-timestep scratch arrays of length Tmax, so Tmax
// must be a compile-time constant (presumably passed as -DTmax=<context size>
// by the Python loader -- confirm against the build flags).
#ifndef Tmax
#error "wkv_cuda.cu: Tmax must be defined at compile time (e.g. -DTmax=1024)"
#endif

// Converts a failed kernel launch into a C++ exception so the error reaches the
// caller immediately instead of surfacing at a later, unrelated CUDA call.
static void check_launch(const char *kernel_name) {
    const cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess) {
        throw std::runtime_error(std::string(kernel_name) + " launch failed: " + cudaGetErrorString(err));
    }
}

// WKV forward recurrence.
//
// Layout: k, v and y are dense (B, T, C) buffers -- element (b, t, c) lives at
// b * T * C + t * C + c; w and u hold one value per channel.
// Launch: 1-D grid, one thread per (batch, channel) pair. Threads with
// idx >= B * C return immediately, so the grid may be rounded up.
template <typename F>
__global__ void kernel_forward(const int B, const int T, const int C,
                               const F *__restrict__ const _w, const F *__restrict__ const _u,
                               const F *__restrict__ const _k, const F *__restrict__ const _v,
                               F *__restrict__ const _y) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * C) {
        return;  // padding thread of the rounded-up grid
    }
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    F *__restrict__ const y = _y + _offset;

    // aa and bb are running sums divided by exp(pp) (to avoid overflow)
    F aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;  // consecutive timesteps of one channel are C apart
        const F kk = k[ii];
        const F vv = v[ii];

        // Output at step i: combine the accumulated state with the current
        // token, whose key is boosted by u; p is the shared max exponent.
        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        y[ii] = (e1 * aa + e2 * vv) / (e1 * bb + e2);

        // State update: scale the previous state by exp(w) and add the current token.
        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
}

// Gradient of kernel_forward.
//
// Same layout and launch shape as kernel_forward. Writes gk and gv as (B, T, C)
// buffers and gw, gu as one value per (batch, channel) at b * C + c (presumably
// reduced over the batch by the caller -- confirm). Requires T <= Tmax, which
// cuda_backward enforces before launching.
template <typename F>
__global__ void kernel_backward(const int B, const int T, const int C,
                                const F *__restrict__ const _w, const F *__restrict__ const _u,
                                const F *__restrict__ const _k, const F *__restrict__ const _v,
                                const F *__restrict__ const _y, const F *__restrict__ const _gy,
                                F *__restrict__ const _gw, F *__restrict__ const _gu,
                                F *__restrict__ const _gk, F *__restrict__ const _gv) {
    const int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * C) {
        return;  // padding thread of the rounded-up grid
    }
    const int _b = idx / C;
    const int _c = idx % C;
    const int _offset = _b * T * C + _c;

    F u = _u[_c];
    F w = _w[_c];
    const F *__restrict__ const k = _k + _offset;
    const F *__restrict__ const v = _v + _offset;
    const F *__restrict__ const y = _y + _offset;
    const F *__restrict__ const gy = _gy + _offset;
    F *__restrict__ const gk = _gk + _offset;
    F *__restrict__ const gv = _gv + _offset;

    // Per-timestep values saved by the first sweep and consumed by the reverse sweep.
    F q[Tmax], r[Tmax];

    // First sweep (t = 0 .. T-1): recompute the forward state, accumulate gw and gu.
    F gw = 0, gu = 0, aa = 0, bb = 0, ga = 0, gb = 0, pp = MIN_VALUE;
    for (int i = 0; i < T; i++) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];

        F ww = u + kk;
        F p = max(pp, ww);
        F e1 = exp(pp - p);
        F e2 = exp(ww - p);
        const F qq = gy[ii] / (e1 * bb + e2);
        gw += (ga - gb * yy) * e1 * qq;
        gu += (vv - yy) * e2 * qq;
        q[i] = qq;      // output gradient divided by the step's denominator
        r[i] = ww - p;  // current token's exponent (u + k) relative to the step max

        ww = w + pp;
        p = max(ww, kk);
        e1 = exp(ww - p);
        e2 = exp(kk - p);
        ga = e1 * (aa + ga);
        gb = e1 * (bb + gb);
        aa = e1 * aa + e2 * vv;
        bb = e1 * bb + e2;
        pp = p;
    }
    const int _offsetBC = _b * C + _c;
    _gw[_offsetBC] = gw * _w[_c]; // multiply by w because of w -> -exp(w) in python forward()
    _gu[_offsetBC] = gu;

    // Reverse sweep (t = T-1 .. 0): propagate gradients to k and v.
    aa = 0, bb = 0, pp = MIN_VALUE;
    for (int i = T - 1; i >= 0; i--) {
        const int ii = i * C;
        const F kk = k[ii];
        const F vv = v[ii];
        const F yy = y[ii];
        const F qq = q[i];
        const F rr = r[i];

        F e1 = qq * exp(rr);
        F e2 = exp(kk + pp);
        gk[ii] = e1 * (vv - yy) + e2 * (aa * vv + bb);
        gv[ii] = e1 + e2 * aa;

        const F ww = w + pp;
        const F www = rr - u - kk;
        const F p = max(ww, www);
        e1 = exp(ww - p);
        e2 = qq * exp(www - p);
        aa = e1 * aa + e2;
        bb = e1 * bb - e2 * yy;
        pp = p;
    }
}

// Host launcher for kernel_forward (float32, default stream).
// Blocks hold min(C, 32) threads and the grid is rounded up, so B * C no longer
// has to be a multiple of the block size; the kernel skips surplus threads.
// Empty inputs (B * C == 0) are a no-op. Throws std::runtime_error if the launch fails.
void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
    const int total = B * C;
    if (total <= 0) {
        return;  // nothing to compute; a zero-sized launch is an invalid configuration
    }
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    dim3 numBlocks((total + threadsPerBlock.x - 1) / threadsPerBlock.x);
    kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
    check_launch("kernel_forward");
}

// Host launcher for kernel_backward (float32, default stream).
// Same launch shape as cuda_forward. Throws std::runtime_error when T exceeds
// the compile-time Tmax (the kernel's scratch arrays would overflow) or when
// the launch fails.
void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv) {
    if (T > Tmax) {
        throw std::runtime_error("wkv backward: sequence length " + std::to_string(T) +
                                 " exceeds compile-time Tmax " + std::to_string(Tmax));
    }
    const int total = B * C;
    if (total <= 0) {
        return;  // nothing to compute; a zero-sized launch is an invalid configuration
    }
    dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
    dim3 numBlocks((total + threadsPerBlock.x - 1) / threadsPerBlock.x);
    kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y, gy, gw, gu, gk, gv);
    check_launch("kernel_backward");
}
| New file |
| | |
| | | /* |
| | | * Based on https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/cuda/wkv_op.cpp |
| | | Function signatures were modified based on https://github.com/huggingface/transformers/blob/main/src/transformers/kernels/rwkv/wkv_op.cpp |
| | | |
| | | */ |
| | | |
| | | #include <torch/extension.h> |
| | | |
| | | void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y); |
| | | |
| | | void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y, float *gy, float *gw, float *gu, float *gk, float *gv); |
| | | |
| | | void forward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) { |
| | | const int B = k.size(0); |
| | | const int T = k.size(1); |
| | | const int C = k.size(2); |
| | | |
| | | cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>()); |
| | | } |
| | | |
| | | void backward(torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) { |
| | | const int B = k.size(0); |
| | | const int T = k.size(1); |
| | | const int C = k.size(2); |
| | | |
| | | cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>()); |
| | | } |
| | | |
| | | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { |
| | | m.def("forward", &forward, "wkv forward"); |
| | | m.def("backward", &backward, "wkv backward"); |
| | | } |
| | | |
| | | TORCH_LIBRARY(wkv_encoder, m) { |
| | | m.def("forward", forward); |
| | | m.def("backward", backward); |
| | | } |
| | |
| | | |
| | | """ |
| | | num_state, den_state, max_state = state |
| | | |
| | | time_decay = -torch.exp(time_decay) |
| | | max_for_output = torch.maximum(max_state, (time_first + key)) |
| | | |
| | | e1 = torch.exp(max_state - max_for_output) |
| | |
| | | dropout_rate, |
| | | num_blocks |
| | | ) |
| | | load_decoder_wkv_kernel(context_size) |
| | | # load_decoder_wkv_kernel(context_size) |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | dropout_rate, |
| | | num_blocks |
| | | ) |
| | | load_encoder_wkv_kernel(context_size) |
| | | # load_encoder_wkv_kernel(context_size) |
| | | |
| | | def forward( |
| | | self, |