From 4e0404e04ed890717ead276e52c927a820326ec1 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 01 十一月 2023 16:47:13 +0800
Subject: [PATCH] fix rwkv infer bugs

---
 funasr/models/encoder/rwkv_encoder.py |   15 ++++++++-------
 1 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/funasr/models/encoder/rwkv_encoder.py b/funasr/models/encoder/rwkv_encoder.py
index 8a33520..40151bf 100644
--- a/funasr/models/encoder/rwkv_encoder.py
+++ b/funasr/models/encoder/rwkv_encoder.py
@@ -113,11 +113,12 @@
         x = self.embed_norm(x)
         olens = mask.eq(0).sum(1)
 
-        for block in self.rwkv_blocks:
-            x, _ = block(x)
-        # for streaming inference
-        # xs_pad = self.rwkv_infer(xs_pad)
+        # for training
+        # for block in self.rwkv_blocks:
+        #     x, _ = block(x)
 
+        # for streaming inference
+        x = self.rwkv_infer(x)
         x = self.final_norm(x)
 
         if self.time_reduction_factor > 1:
@@ -136,9 +137,9 @@
 
         state = [
             torch.zeros(
-                (batch_size, 1, hidden_sizes[i], self.num_rwkv_blocks),
+                (batch_size, 1, hidden_sizes[i], self.num_blocks),
                 dtype=torch.float32,
-                device=self.device,
+                device=xs_pad.device,
             )
             for i in range(5)
         ]
@@ -151,5 +152,5 @@
             for idx, block in enumerate(self.rwkv_blocks):
                 x_t, state = block(x_t, state=state)
             xs_out.append(x_t)
-        xs_out = torch.stack(xs_out, dim=1)
+        xs_out = torch.cat(xs_out, dim=1)
         return xs_out

--
Gitblit v1.9.1