import os, math, gc, importlib
import torch

# torch._C._jit_set_profiling_executor(True)
# torch._C._jit_set_profiling_mode(True)
import torch.nn as nn
from torch.nn import functional as F

# deepspeed is only needed when gradient checkpointing is enabled (args.grad_cp == 1);
# guard the import so the module still loads without it.
try:
    import deepspeed
except ImportError:
    deepspeed = None


def __nop(ob):
    return ob


MyModule = nn.Module
MyFunction = __nop
if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
    MyModule = torch.jit.ScriptModule
    MyFunction = torch.jit.script_method

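# Note: set RWKV_JIT_ON=1 in the environment before importing this module to compile the
# modules below with TorchScript; otherwise MyModule/MyFunction stay as plain nn.Module
# and a no-op decorator.
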
########################################################################################################
# CUDA Kernel


wkv6_cuda = None

def load_rwkv_kernel(
    HEAD_SIZE: int = 64,
    RWKV_CTXLEN: int = 512,
):
    from torch.utils.cpp_extension import load

    global wkv6_cuda

    if wkv6_cuda is not None:
        return

    absolute_file_path = os.path.abspath(__file__)
    cur_dir = os.path.dirname(absolute_file_path)
    wkv6_cuda = load(
        name="wkv6",
        sources=[f"{cur_dir}/cuda/wkv6_op.cpp", f"{cur_dir}/cuda/wkv6_cuda.cu"],
        verbose=True,
        extra_cuda_cflags=[
            "-res-usage",
            "--use_fast_math",
            "-O3",
            "-Xptxas -O3",
            "--extra-device-vectorization",
            f"-D_N_={HEAD_SIZE}",
            f"-D_T_={RWKV_CTXLEN}",
        ],
    )

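# Usage sketch (illustrative; RWKV_Tmix_x060.__init__ already calls this automatically):
# the kernel is JIT-compiled once per process, with HEAD_SIZE baked in as _N_ and
# RWKV_CTXLEN as _T_, so HEAD_SIZE must match args.head_size_a and RWKV_CTXLEN should
# cover the longest sequence passed to RUN_CUDA_RWKV6.
#
#     load_rwkv_kernel(HEAD_SIZE=64, RWKV_CTXLEN=512)
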

# dtype = torch.float
dtype = torch.bfloat16


class WKV_6(torch.autograd.Function):
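    """Autograd wrapper around the fused wkv6 CUDA kernel.

    r, k, v, w are (B, T, C) contiguous tensors, u (time_faaaa) is (H, C // H), and the
    kernel consumes the decay as -exp(w); the output y is (B, T, C) in the module-level
    `dtype` (bfloat16 by default).
    """
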
    @staticmethod
    def forward(ctx, B, T, C, H, r, k, v, w, u):
        with torch.no_grad():
            # assert r.dtype == torch.bfloat16
            # assert k.dtype == torch.bfloat16
            # assert v.dtype == torch.bfloat16
            # assert w.dtype == torch.bfloat16
            # assert u.dtype == torch.bfloat16
            # assert HEAD_SIZE == C // H
            ctx.B = B
            ctx.T = T
            ctx.C = C
            ctx.H = H
            assert r.is_contiguous()
            assert k.is_contiguous()
            assert v.is_contiguous()
            assert w.is_contiguous()
            assert u.is_contiguous()
            ew = (-torch.exp(w.float())).contiguous()
            ctx.save_for_backward(r, k, v, ew, u)
            y = torch.empty(
                (B, T, C), device=r.device, dtype=dtype, memory_format=torch.contiguous_format
            )  # .uniform_(-100, 100)
            wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y)
            return y

    @staticmethod
    def backward(ctx, gy):
        with torch.no_grad():
            # assert gy.dtype == torch.bfloat16
            B = ctx.B
            T = ctx.T
            C = ctx.C
            H = ctx.H
            assert gy.is_contiguous()
            r, k, v, ew, u = ctx.saved_tensors
            gr = torch.empty(
                (B, T, C),
                device=gy.device,
                requires_grad=False,
                dtype=dtype,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-100, 100)
            gk = torch.empty(
                (B, T, C),
                device=gy.device,
                requires_grad=False,
                dtype=dtype,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-100, 100)
            gv = torch.empty(
                (B, T, C),
                device=gy.device,
                requires_grad=False,
                dtype=dtype,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-100, 100)
            gw = torch.empty(
                (B, T, C),
                device=gy.device,
                requires_grad=False,
                dtype=dtype,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-100, 100)
            gu = torch.empty(
                (B, C),
                device=gy.device,
                requires_grad=False,
                dtype=dtype,
                memory_format=torch.contiguous_format,
            )  # .uniform_(-100, 100)
            wkv6_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu)
            gu = torch.sum(gu, 0).view(H, C // H)
            return (None, None, None, None, gr, gk, gv, gw, gu)


def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
    return WKV_6.apply(B, T, C, H, r, k, v, w, u)


class RWKV_Tmix_x060(MyModule):
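    """RWKV x060 time-mix (the attention-like half of a block).

    jit_func interpolates each token with its time-shifted predecessor and uses a
    low-rank projection (time_maa_w1 / time_maa_w2, rank D_MIX_LORA per target) to
    produce five data-dependent mixing offsets (mw, mk, mv, mr, mg); the per-channel
    decay w is time_decay plus a low-rank, input-dependent correction
    (time_decay_w1 / time_decay_w2).
    """
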
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args

        load_rwkv_kernel(args.head_size_a, args.ctx_len)

        self.layer_id = layer_id

        self.head_size = args.head_size_a
        self.n_head = args.dim_att // self.head_size
        assert args.dim_att % self.n_head == 0

        with torch.no_grad():
            ratio_0_to_1 = layer_id / (args.n_layer - 1)  # 0 to 1
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd

            # fancy time_mix
            self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_v = nn.Parameter(
                1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
            )
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
            self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))

            D_MIX_LORA = 32  # generate TIME_MIX for w,k,v,r,g
            self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
            self.time_maa_w2 = nn.Parameter(
                torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
            )

            # fancy time_decay
            decay_speed = torch.ones(args.dim_att)
            for n in range(args.dim_att):
                decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
            self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))

            D_DECAY_LORA = 64
            self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
            self.time_decay_w2 = nn.Parameter(
                torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
            )

            tmp = torch.zeros(args.dim_att)
            for n in range(args.dim_att):
                zigzag = ((n + 1) % 3 - 1) * 0.1
                tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag

            self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))

        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)

        self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
        self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
        self.ln_x = nn.GroupNorm(
            self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
        )

    @MyFunction
    def jit_func(self, x):
        B, T, C = x.size()

        xx = self.time_shift(x) - x

        xxx = x + xx * self.time_maa_x
        xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
        xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
        mw, mk, mv, mr, mg = xxx.unbind(dim=0)

        xw = x + xx * (self.time_maa_w + mw)
        xk = x + xx * (self.time_maa_k + mk)
        xv = x + xx * (self.time_maa_v + mv)
        xr = x + xx * (self.time_maa_r + mr)
        xg = x + xx * (self.time_maa_g + mg)

        r = self.receptance(xr)
        k = self.key(xk)
        v = self.value(xv)
        g = F.silu(self.gate(xg))

        ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
        w = self.time_decay + ww

        return r, k, v, g, w

    @MyFunction
    def jit_func_2(self, x, g):
        B, T, C = x.size()
        x = x.view(B * T, C)

        x = self.ln_x(x).view(B, T, C)
        x = self.output(x * g)
        return x

    def forward(self, x, **kwargs):
        B, T, C = x.size()
        H = self.n_head

        r, k, v, g, w = self.jit_func(x)
        x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)

        return self.jit_func_2(x, g)


class RWKV_CMix_x060(MyModule):
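    """RWKV x060 channel-mix (feed-forward) half: token-shifted k/r inputs, a
    squared-ReLU key activation, and a sigmoid receptance gate on the value path."""
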
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))

        with torch.no_grad():  # fancy init of time_mix
            ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer)  # 1 to ~0
            ddd = torch.ones(1, 1, args.n_embd)
            for i in range(args.n_embd):
                ddd[0, 0, i] = i / args.n_embd
            self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
            self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))

        self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
        self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
        self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)

    @MyFunction
    def forward(self, x):
        xx = self.time_shift(x) - x
        xk = x + xx * self.time_maa_k
        xr = x + xx * self.time_maa_r

        k = self.key(xk)
        k = torch.relu(k) ** 2
        kv = self.value(k)
        return torch.sigmoid(self.receptance(xr)) * kv


class Block(nn.Module):
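    """One RWKV residual block: ln0 on layer 0, time-mix (att) then channel-mix (ffn),
    with optional dropout on each residual branch."""
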
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id

        self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)

        if self.layer_id == 0:
            self.ln0 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_Tmix_x060(args, layer_id)

        self.ffn = RWKV_CMix_x060(args, layer_id)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

    def forward(self, x, x_emb=None):
        args = self.args
        B, T, C = x.size()
        if self.layer_id == 0:
            x = self.ln0(x)

        if self.args.dropout == 0:
            if self.layer_id == 0 and args.pre_ffn > 0:
                # NOTE: self.ffnPre is referenced here but never defined in this file;
                # this branch is only reachable when args.pre_ffn > 0.
                x = x + self.ffnPre(self.ln1(x))
            else:
                x = x + self.att(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
        else:
            if self.layer_id == 0 and args.pre_ffn > 0:
                x = self.drop0(x + self.ffnPre(self.ln1(x)))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))

        return x


class RWKVLayer(nn.Module):
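    """Standalone RWKV layer variant.

    Unlike Block, it expects a dict-like `args` that supports .get() (for the optional
    ln0 / ln1 / use_rwkv_ffn / init_rwkv / datatype switches), can skip the FFN entirely,
    and optionally applies RWKV-style orthogonal/zero initialization to its sub-layers.
    """
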
    def __init__(self, args, layer_id):
        super().__init__()
        self.args = args
        self.layer_id = layer_id
        if args.dim_ffn is None:
            args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)

        self.ln0 = None
        if self.layer_id == 0 and args.get("ln0", True):
            self.ln0 = nn.LayerNorm(args.n_embd)

        self.ln1 = None
        if args.get("ln1", True):
            self.ln1 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_Tmix_x060(args, layer_id)

        self.ln2 = None
        self.ffn = None
        if args.get("use_rwkv_ffn", True):
            self.ln2 = nn.LayerNorm(args.n_embd)
            self.ffn = RWKV_CMix_x060(args, layer_id)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

        # init
        if args.get("init_rwkv", True):
            print("init_rwkv")
            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
            nn.init.orthogonal_(self.att.value.weight, gain=1)
            nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
            nn.init.zeros_(self.att.output.weight)

            # guarded so initialization also works when use_rwkv_ffn is False
            if self.ffn is not None:
                nn.init.orthogonal_(self.ffn.key.weight, gain=1)
                nn.init.zeros_(self.ffn.value.weight)
                nn.init.zeros_(self.ffn.receptance.weight)

            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            if self.ln0 is not None:
                nn.init.constant_(self.ln0.weight, scale)
            if self.ln1 is not None:
                nn.init.constant_(self.ln1.weight, scale)
            if self.ln2 is not None:
                nn.init.constant_(self.ln2.weight, scale)

    def forward(self, x, x_emb=None, mask=None, **kwargs):
        args = self.args
        if args.get("datatype", "bf16") == "bf16":
            x = x.bfloat16()
        B, T, C = x.size()
        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)

        if self.args.dropout == 0:
            if self.ln1 is None:
                x = x + self.att(x)
            else:
                x = x + self.att(self.ln1(x))
            if self.ffn is not None:
                x = x + self.ffn(self.ln2(x))
        else:
            if self.ln1 is None:
                x = self.drop0(x + self.att(x))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            if self.ffn is not None:
                x = self.drop1(x + self.ffn(self.ln2(x)))

        if args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)
        return x


class RWKV(nn.Module):
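    """Full RWKV x060 language model: token embedding, a stack of Blocks, a final
    LayerNorm, and a linear head (plus an optional head_qk "copy" path)."""
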
    def __init__(self, args):
        super().__init__()
        self.args = args
        if not hasattr(args, "dim_att"):
            args.dim_att = args.n_embd
        if not hasattr(args, "dim_ffn"):
            if "-f4" in os.environ["RWKV_MY_TESTING"]:
                args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
            else:
                args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)  # default = 3.5x emb size
        if not hasattr(args, "tiny_att_layer"):
            args.tiny_att_layer = -1
        if not hasattr(args, "tiny_att_dim"):
            args.tiny_att_dim = -1
        assert args.n_embd % 32 == 0
        assert args.dim_att % 32 == 0
        assert args.dim_ffn % 32 == 0

        self.emb = nn.Embedding(args.vocab_size, args.n_embd)

        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
        self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])

        self.ln_out = nn.LayerNorm(args.n_embd)
        self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)

    def forward(self, idx):
        args = self.args
        B, T = idx.size()
        assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."

        x = self.emb(idx)
        x_emb = x

        if args.dropout > 0:
            x = self.drop0(x)
        if args.tiny_att_dim > 0:
            for block in self.blocks:
                if args.grad_cp == 1:
                    # gradient checkpointing path; requires deepspeed (guarded import above)
                    x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
                else:
                    x = block(x, x_emb)
        else:
            for block in self.blocks:
                if args.grad_cp == 1:
                    x = deepspeed.checkpointing.checkpoint(block, x)
                else:
                    x = block(x)

        x = self.ln_out(x)

        if args.head_qk > 0:
            q = self.head_q(x)[:, :T, :]
            k = self.head_k(x)[:, :T, :]
            c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
            c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)

            if "32" in os.environ["RWKV_FLOAT_MODE"]:
                c = c @ F.one_hot(idx, num_classes=args.vocab_size)
            elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
            elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
                c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()

            x = self.head(x) + c
        else:
            x = self.head(x)

        return x
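

if __name__ == "__main__":
    # Minimal smoke-test sketch. Assumptions (not taken from the original file): a CUDA
    # device is available, the wkv6 kernel sources under ./cuda compile, and the
    # SimpleNamespace fields below are inferred from the attributes this module reads.
    from types import SimpleNamespace

    args = SimpleNamespace(
        n_layer=2,
        n_embd=128,
        vocab_size=256,
        ctx_len=512,
        head_size_a=64,
        head_size_divisor=8,
        dim_att=128,
        dim_ffn=448,
        dropout=0,
        pre_ffn=0,
        grad_cp=0,
        tiny_att_dim=-1,
        tiny_att_layer=-1,
        head_qk=0,
    )
    model = RWKV(args).cuda().bfloat16()
    idx = torch.randint(0, args.vocab_size, (1, args.ctx_len), device="cuda")
    logits = model(idx)  # (1, ctx_len, vocab_size)
    print(logits.shape)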