From fc68b5ffe453235294a561737d8e84bb6c1689a4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 25 四月 2024 21:43:47 +0800
Subject: [PATCH] Dev gzf exp (#1661)
---
funasr/models/conformer_rwkv/decoder.py | 46 +++++++++++++++++++++++++++++++++++++++-------
1 files changed, 39 insertions(+), 7 deletions(-)
diff --git a/funasr/models/conformer_rwkv/decoder.py b/funasr/models/conformer_rwkv/decoder.py
index 90e56e5..5e2ac12 100644
--- a/funasr/models/conformer_rwkv/decoder.py
+++ b/funasr/models/conformer_rwkv/decoder.py
@@ -29,6 +29,11 @@
from funasr.register import tables
+class LayerNorm(nn.LayerNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
class DecoderLayer(nn.Module):
"""Single decoder layer module.
@@ -54,7 +59,7 @@
def __init__(
self,
size,
- self_attn,
+ # self_attn,
src_attn,
feed_forward,
dropout_rate,
@@ -62,11 +67,12 @@
concat_after=False,
layer_id=None,
args={},
+ **kwargs,
):
"""Construct an DecoderLayer object."""
super(DecoderLayer, self).__init__()
self.size = size
- self.self_attn = self_attn.to(torch.bfloat16)
+ # self.self_attn = self_attn.to(torch.bfloat16)
self.src_attn = src_attn
self.feed_forward = feed_forward
self.norm1 = LayerNorm(size)
@@ -79,6 +85,22 @@
self.concat_linear1 = nn.Linear(size + size, size)
self.concat_linear2 = nn.Linear(size + size, size)
self.layer_id = layer_id
+
+ if args.get("version", "v4") == "v4":
+ from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
+ elif args.get("version", "v5") == "v5":
+ from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
+ else:
+ from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+ from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
+ # self.attn = RWKVLayer(args=args, layer_id=layer_id)
+ self.self_attn = RWKV_Tmix(args, layer_id=layer_id)
+ if args.get("datatype", "bf16") == "bf16":
+ self.self_attn.to(torch.bfloat16)
+ # self.norm1.to(torch.bfloat16)
+ self.args = args
self.ln0 = None
if self.layer_id == 0 and not args.get("ln0", True):
self.ln0 = LayerNorm(args.n_embd)
@@ -93,7 +115,15 @@
print("init_rwkv")
scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
nn.init.constant_(self.norm1.weight, scale)
- nn.init.constant_(self.self_attn.ln2.weight, scale)
+ # nn.init.constant_(self.self_attn.ln2.weight, scale)
+
+ if args.get("init_rwkv", True):
+ print("init_rwkv")
+ nn.init.orthogonal_(self.self_attn.receptance.weight, gain=1)
+ nn.init.orthogonal_(self.self_attn.key.weight, gain=0.1)
+ nn.init.orthogonal_(self.self_attn.value.weight, gain=1)
+ nn.init.orthogonal_(self.self_attn.gate.weight, gain=0.1)
+ nn.init.zeros_(self.self_attn.output.weight)
def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
"""Compute decoded features.
@@ -117,6 +147,8 @@
if self.layer_id == 0 and self.ln0 is not None:
tgt = self.ln0(tgt)
+ if self.args.get("datatype", "bf16") == "bf16":
+ tgt = tgt.bfloat16()
residual = tgt
tgt = self.norm1(tgt)
@@ -132,7 +164,8 @@
x = residual + self.dropout(self.self_attn(tgt, mask=tgt_q_mask))
x = x[:, -1, :]
-
+ if self.args.get("datatype", "bf16") == "bf16":
+ x = x.to(torch.float32)
# x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
residual = x
@@ -370,17 +403,16 @@
pos_enc_class=pos_enc_class,
normalize_before=normalize_before,
)
- from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+ # from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
rwkv_cfg = kwargs.get("rwkv_cfg", {})
args = OmegaConf.create(rwkv_cfg)
- # self.attn = RWKVLayer(args=args, layer_id=layer_id)
+
attention_dim = encoder_output_size
self.decoders = repeat(
num_blocks,
lambda lnum: DecoderLayer(
attention_dim,
- RWKVLayer(args=args, layer_id=lnum),
MultiHeadedAttention(attention_heads, attention_dim, src_attention_dropout_rate),
PositionwiseFeedForward(attention_dim, linear_units, dropout_rate),
dropout_rate,
--
Gitblit v1.9.1