zhifu gao
2024-04-26 e971e000ad582c767ae44c9650470899f5bb46d0
funasr/models/conformer_rwkv/decoder.py
@@ -97,9 +97,7 @@
            from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
        # self.attn = RWKVLayer(args=args, layer_id=layer_id)
        self.self_attn = RWKV_Tmix(args, layer_id=layer_id)
        if args.get("datatype", "bf16") == "bf16":
            self.self_attn.to(torch.bfloat16)
            # self.norm1.to(torch.bfloat16)
        self.args = args
        self.ln0 = None
        if self.layer_id == 0 and not args.get("ln0", True):
@@ -125,6 +123,10 @@
            nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1)
            nn.init.zeros_(self.self_attn.output.weight)
        if args.get("datatype", "bf16") == "bf16":
            self.self_attn.to(torch.bfloat16)
            # self.norm1.to(torch.bfloat16)
    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.