From fc68b5ffe453235294a561737d8e84bb6c1689a4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 25 Apr 2024 21:43:47 +0800
Subject: [PATCH] Dev gzf exp (#1661)
---
funasr/models/sense_voice/rwkv_v6.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)
diff --git a/funasr/models/sense_voice/rwkv_v6.py b/funasr/models/sense_voice/rwkv_v6.py
index 36269a1..b91d47a 100644
--- a/funasr/models/sense_voice/rwkv_v6.py
+++ b/funasr/models/sense_voice/rwkv_v6.py
@@ -244,7 +244,7 @@
x = self.output(x * g)
return x
- def forward(self, x):
+ def forward(self, x, **kwargs):
B, T, C = x.size()
H = self.n_head
@@ -341,11 +341,14 @@
self.ln1 = None
if args.get("ln1", True):
self.ln1 = nn.LayerNorm(args.n_embd)
- self.ln2 = nn.LayerNorm(args.n_embd)
self.att = RWKV_Tmix_x060(args, layer_id)
- self.ffn = RWKV_CMix_x060(args, layer_id)
+ self.ln2 = None
+ self.ffn = None
+ if args.get("use_rwkv_ffn", True):
+ self.ln2 = nn.LayerNorm(args.n_embd)
+ self.ffn = RWKV_CMix_x060(args, layer_id)
if args.dropout > 0:
self.drop0 = nn.Dropout(p=args.dropout)
@@ -364,11 +367,13 @@
nn.init.zeros_(self.ffn.value.weight)
nn.init.zeros_(self.ffn.receptance.weight)
scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
- nn.init.constant_(self.ln2.weight, scale)
+
if self.ln0 is not None:
nn.init.constant_(self.ln0.weight, scale)
if self.ln1 is not None:
nn.init.constant_(self.ln1.weight, scale)
+ if self.ln2 is not None:
+ nn.init.constant_(self.ln2.weight, scale)
def forward(self, x, x_emb=None, mask=None, **kwargs):
@@ -384,13 +389,15 @@
x = x + self.att(x)
else:
x = x + self.att(self.ln1(x))
- x = x + self.ffn(self.ln2(x))
+ if self.ffn is not None:
+ x = x + self.ffn(self.ln2(x))
else:
if self.ln1 is None:
x = self.drop0(x + self.att(x))
else:
x = self.drop0(x + self.att(self.ln1(x)))
- x = self.drop1(x + self.ffn(self.ln2(x)))
+ if self.ffn is not None:
+ x = self.drop1(x + self.ffn(self.ln2(x)))
if args.get("datatype", "bf16") == "bf16":
x = x.to(torch.float32)
--
Gitblit v1.9.1