From fc68b5ffe453235294a561737d8e84bb6c1689a4 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 25 Apr 2024 21:43:47 +0800
Subject: [PATCH] Dev gzf exp (#1661)

---
 funasr/models/sense_voice/rwkv_v6.py |   19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/funasr/models/sense_voice/rwkv_v6.py b/funasr/models/sense_voice/rwkv_v6.py
index 36269a1..b91d47a 100644
--- a/funasr/models/sense_voice/rwkv_v6.py
+++ b/funasr/models/sense_voice/rwkv_v6.py
@@ -244,7 +244,7 @@
         x = self.output(x * g)
         return x
 
-    def forward(self, x):
+    def forward(self, x, **kwargs):
         B, T, C = x.size()
         H = self.n_head
 
@@ -341,11 +341,14 @@
         self.ln1 = None
         if args.get("ln1", True):
             self.ln1 = nn.LayerNorm(args.n_embd)
-        self.ln2 = nn.LayerNorm(args.n_embd)
 
         self.att = RWKV_Tmix_x060(args, layer_id)
 
-        self.ffn = RWKV_CMix_x060(args, layer_id)
+        self.ln2 = None
+        self.ffn = None
+        if args.get("use_rwkv_ffn", True):
+            self.ln2 = nn.LayerNorm(args.n_embd)
+            self.ffn = RWKV_CMix_x060(args, layer_id)
 
         if args.dropout > 0:
             self.drop0 = nn.Dropout(p=args.dropout)
@@ -364,11 +367,13 @@
             nn.init.zeros_(self.ffn.value.weight)
             nn.init.zeros_(self.ffn.receptance.weight)
             scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
-            nn.init.constant_(self.ln2.weight, scale)
+
             if self.ln0 is not None:
                 nn.init.constant_(self.ln0.weight, scale)
             if self.ln1 is not None:
                 nn.init.constant_(self.ln1.weight, scale)
+            if self.ln2 is not None:
+                nn.init.constant_(self.ln2.weight, scale)
 
     def forward(self, x, x_emb=None, mask=None, **kwargs):
 
@@ -384,13 +389,15 @@
                 x = x + self.att(x)
             else:
                 x = x + self.att(self.ln1(x))
-            x = x + self.ffn(self.ln2(x))
+            if self.ffn is not None:
+                x = x + self.ffn(self.ln2(x))
         else:
             if self.ln1 is None:
                 x = self.drop0(x + self.att(x))
             else:
                 x = self.drop0(x + self.att(self.ln1(x)))
-            x = self.drop1(x + self.ffn(self.ln2(x)))
+            if self.ffn is not None:
+                x = self.drop1(x + self.ffn(self.ln2(x)))
 
         if args.get("datatype", "bf16") == "bf16":
             x = x.to(torch.float32)

--
Gitblit v1.9.1