        x = self.output(x * g)
        return x

    def forward(self, x, **kwargs):
        B, T, C = x.size()
        H = self.n_head
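        # (The body of this forward is truncated here. For orientation, the
        # reference RWKV-LM x060 time-mix forward continues roughly as
        # follows; the modified version in this file may differ:
        #
        #     r, k, v, g, w = self.jit_func(x)
        #     x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
        #     return self.jit_func_2(x, g)
        # )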

        self.ln1 = None
        if args.get("ln1", True):
            self.ln1 = nn.LayerNorm(args.n_embd)

        self.att = RWKV_Tmix_x060(args, layer_id)

        self.ln2 = None
        self.ffn = None
        if args.get("use_rwkv_ffn", True):
            self.ln2 = nn.LayerNorm(args.n_embd)
            self.ffn = RWKV_CMix_x060(args, layer_id)

        if args.dropout > 0:
            # drop0/drop1 apply after the attention and FFN residuals
            self.drop0 = nn.Dropout(p=args.dropout)
            self.drop1 = nn.Dropout(p=args.dropout)

        if self.ffn is not None:
            # start the channel-mix contribution at zero
            nn.init.zeros_(self.ffn.value.weight)
            nn.init.zeros_(self.ffn.receptance.weight)

        # depth-dependent gain for this block's LayerNorms
        scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
        if self.ln0 is not None:
            nn.init.constant_(self.ln0.weight, scale)
        if self.ln1 is not None:
            nn.init.constant_(self.ln1.weight, scale)
        if self.ln2 is not None:
            nn.init.constant_(self.ln2.weight, scale)
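        # With n_layer = 12 this gain ramps from ((1 + 0) / 12) ** 0.7 ≈ 0.18
        # at the first block to ((1 + 11) / 12) ** 0.7 = 1.0 at the last, so
        # early blocks start with strongly damped LayerNorm weights.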

    def forward(self, x, x_emb=None, mask=None, **kwargs):
        if self.args.dropout == 0:
            if self.ln1 is None:
                x = x + self.att(x)
            else:
                x = x + self.att(self.ln1(x))
            if self.ffn is not None:
                x = x + self.ffn(self.ln2(x))
        else:
            if self.ln1 is None:
                x = self.drop0(x + self.att(x))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            if self.ffn is not None:
                x = self.drop1(x + self.ffn(self.ln2(x)))
        return x

        if args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)
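
# The config object `args` is used above with both attribute access
# (args.dropout, args.n_embd) and dict-style lookups with defaults
# (args.get("ln1", True)), so it is presumably a dict-like namespace. A
# minimal stand-in that satisfies both patterns; `BlockArgs` is a
# hypothetical name, not from this codebase:
class BlockArgs(dict):
    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

args = BlockArgs(n_embd=768, n_layer=12, dropout=0.0)
assert args.n_embd == 768               # attribute access
assert args.get("ln1", True) is True    # dict-style access with a default
# (a real config also needs whatever fields RWKV_Tmix_x060 and
#  RWKV_CMix_x060 read; those are omitted here)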