
        # merge heads back: (B, H, T, d) -> (B, T, H, d) -> (B, T, H*d); qk is
        # returned detached so callers can inspect the attention logits
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()


class ResidualAttentionBlockRWKV(nn.Module):
    def __init__(
        self,
        n_state: int,
        n_head: int,
        cross_attention: bool = False,
        layer_id: int = 0,
        **kwargs,
    ):
        super().__init__()

        from omegaconf import OmegaConf

        rwkv_cfg = kwargs.get("rwkv_cfg", {})
        args = OmegaConf.create(rwkv_cfg)
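        # rwkv_cfg drives everything below; a minimal sketch, with keys taken
        # from the args.get(...) calls in this class (RWKV_Tmix itself needs
        # more fields, e.g. embedding/head sizes, which are omitted here):
        #   rwkv_cfg = {
        #       "version": "v6",     # selects rwkv_v4 / rwkv_v5 / rwkv_v6
        #       "datatype": "bf16",  # run the RWKV branch in bfloat16
        #       "init_rwkv": True,   # re-initialize R/K/V/G projections
        #       "ln0": True,         # controls the extra layer-0 LayerNorm below
        #       "n_layer": 12,       # depth, used to scale LayerNorm init
        #   }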
        if args.get("version", "v4") == "v4":
            from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
        elif args.get("version", "v5") == "v5":
            from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
        else:
            from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
            from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
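        # RWKV_Tmix is the time-mixing ("attention") half of an RWKV block and
        # is used here as a drop-in replacement for self-attention; the
        # commented RWKVLayer alternative swaps in the full RWKV layer instead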
        # self.att = RWKVLayer(args=args, layer_id=layer_id)
        self.att = RWKV_Tmix(args, layer_id=layer_id)

        if args.get("init_rwkv", True):
            print("init_rwkv")
            # orthogonal init for the R/K/V/G projections; the output
            # projection starts at zero, so the RWKV branch contributes
            # nothing at initialization and the residual block is an identity
            nn.init.orthogonal_(self.att.receptance.weight, gain=1)
            nn.init.orthogonal_(self.att.key.weight, gain=0.1)
            nn.init.orthogonal_(self.att.value.weight, gain=1)
            nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
            nn.init.zeros_(self.att.output.weight)

        self.ln0 = None
        if layer_id == 0 and not args.get("ln0", True):
            self.ln0 = LayerNorm(n_state)
            if args.get("init_rwkv", True):
                layer_id = 0
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln0.weight, scale)
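        # the (1 + layer_id) / n_layer scaling gives early layers smaller
        # LayerNorm gains, e.g. layer 0 of n_layer=12: (1 / 12) ** 0.7 ≈ 0.18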

        self.layer_id = layer_id
        self.args = args

        self.ln1 = None
        if args.get("ln1", True):  # "ln1" key assumed; forward() only checks self.ln1 is None
            self.ln1 = LayerNorm(n_state)
            if args.get("init_rwkv", True):
                print("init_rwkv")
                scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
                nn.init.constant_(self.ln1.weight, scale)
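        # ln1 plays the role of whisper's pre-attention LayerNorm; when it is
        # disabled, forward() feeds the raw residual stream into the RWKV branch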

        if args.get("datatype", "bf16") == "bf16":
            self.att.to(torch.bfloat16)
            # if self.ln1 is not None:
            #     self.ln1.to(torch.bfloat16)

        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
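
    # Hypothetical construction sketch (values and cfg keys illustrative only):
    #   block = ResidualAttentionBlockRWKV(
    #       n_state=512, n_head=8, cross_attention=True, layer_id=0,
    #       rwkv_cfg={"version": "v6", "datatype": "bf16", "n_layer": 12},
    #   )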
    def forward(self, x, xa=None, mask=None, kv_cache=None, is_pad_mask=None, is_pad_memory_mask=None):
        if self.layer_id == 0 and self.ln0 is not None:
            x = self.ln0(x)

        if self.args.get("datatype", "bf16") == "bf16":
            x = x.bfloat16()
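        # the RWKV branch runs in bfloat16; x is cast back to float32 below so
        # the surrounding fp32 whisper modules are unaffected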
        if self.ln1 is None:
            x = x + self.att(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        else:
            x = x + self.att(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        if self.args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)

        if self.cross_attn:
            x = (