From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] Fix hotwords bug: cast self_attn to bf16 after weight init
---
funasr/models/conformer_rwkv/decoder.py | 8 +++++---
1 file changed, 5 insertions(+), 3 deletions(-)
diff --git a/funasr/models/conformer_rwkv/decoder.py b/funasr/models/conformer_rwkv/decoder.py
index 5e2ac12..4c41049 100644
--- a/funasr/models/conformer_rwkv/decoder.py
+++ b/funasr/models/conformer_rwkv/decoder.py
@@ -97,9 +97,7 @@
from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
# self.attn = RWKVLayer(args=args, layer_id=layer_id)
self.self_attn = RWKV_Tmix(args, layer_id=layer_id)
- if args.get("datatype", "bf16") == "bf16":
- self.self_attn.to(torch.bfloat16)
- # self.norm1.to(torch.bfloat16)
+
self.args = args
self.ln0 = None
if self.layer_id == 0 and not args.get("ln0", True):
@@ -125,6 +123,10 @@
nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1)
nn.init.zeros_(self.self_attn.output.weight)
+ if args.get("datatype", "bf16") == "bf16":
+ self.self_attn.to(torch.bfloat16)
+ # self.norm1.to(torch.bfloat16)
+
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
--
Gitblit v1.9.1