zhifu gao
2024-04-25 fc68b5ffe453235294a561737d8e84bb6c1689a4
funasr/models/sense_voice/rwkv_v6.py
@@ -244,7 +244,7 @@
        x = self.output(x * g)
        return x
    def forward(self, x):
    def forward(self, x, **kwargs):
        B, T, C = x.size()
        H = self.n_head
@@ -341,11 +341,14 @@
        self.ln1 = None
        if args.get("ln1", True):
            self.ln1 = nn.LayerNorm(args.n_embd)
        self.ln2 = nn.LayerNorm(args.n_embd)
        self.att = RWKV_Tmix_x060(args, layer_id)
        self.ffn = RWKV_CMix_x060(args, layer_id)
        self.ln2 = None
        self.ffn = None
        if args.get("use_rwkv_ffn", True):
            self.ln2 = nn.LayerNorm(args.n_embd)
            self.ffn = RWKV_CMix_x060(args, layer_id)
        if args.dropout > 0:
            self.drop0 = nn.Dropout(p=args.dropout)
@@ -364,11 +367,13 @@
            nn.init.zeros_(self.ffn.value.weight)
            nn.init.zeros_(self.ffn.receptance.weight)
            scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
            nn.init.constant_(self.ln2.weight, scale)
            if self.ln0 is not None:
                nn.init.constant_(self.ln0.weight, scale)
            if self.ln1 is not None:
                nn.init.constant_(self.ln1.weight, scale)
            if self.ln2 is not None:
                nn.init.constant_(self.ln2.weight, scale)
    def forward(self, x, x_emb=None, mask=None, **kwargs):
@@ -384,13 +389,15 @@
                x = x + self.att(x)
            else:
                x = x + self.att(self.ln1(x))
            x = x + self.ffn(self.ln2(x))
            if self.ffn is not None:
                x = x + self.ffn(self.ln2(x))
        else:
            if self.ln1 is None:
                x = self.drop0(x + self.att(x))
            else:
                x = self.drop0(x + self.att(self.ln1(x)))
            x = self.drop1(x + self.ffn(self.ln2(x)))
            if self.ffn is not None:
                x = self.drop1(x + self.ffn(self.ln2(x)))
        if args.get("datatype", "bf16") == "bf16":
            x = x.to(torch.float32)