| | |
| | | from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix |
| | | # self.attn = RWKVLayer(args=args, layer_id=layer_id) |
| | | self.self_attn = RWKV_Tmix(args, layer_id=layer_id) |
| | | if args.get("datatype", "bf16") == "bf16": |
| | | self.self_attn.to(torch.bfloat16) |
| | | # self.norm1.to(torch.bfloat16) |
| | | |
| | | self.args = args |
| | | self.ln0 = None |
| | | if self.layer_id == 0 and not args.get("ln0", True): |
| | |
| | | nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1) |
| | | nn.init.zeros_(self.self_attn.output.weight) |
| | | |
| | | if args.get("datatype", "bf16") == "bf16": |
| | | self.self_attn.to(torch.bfloat16) |
| | | # self.norm1.to(torch.bfloat16) |
| | | |
| | | def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): |
| | | """Compute decoded features. |
| | | |