From 5eabdd444ea07038b6e814a022a212e8a87f6a9a Mon Sep 17 00:00:00 2001
From: 夜雨飘零 <yeyupiaoling@foxmail.com>
Date: Thu, 30 Nov 2023 00:44:17 +0800
Subject: [PATCH] 修复为支持新版本的热词 (#1137)
---
funasr/models/encoder/rwkv_encoder.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)
diff --git a/funasr/models/encoder/rwkv_encoder.py b/funasr/models/encoder/rwkv_encoder.py
index 8a33520..40151bf 100644
--- a/funasr/models/encoder/rwkv_encoder.py
+++ b/funasr/models/encoder/rwkv_encoder.py
@@ -113,11 +113,12 @@
x = self.embed_norm(x)
olens = mask.eq(0).sum(1)
- for block in self.rwkv_blocks:
- x, _ = block(x)
- # for streaming inference
- # xs_pad = self.rwkv_infer(xs_pad)
+ # for training
+ # for block in self.rwkv_blocks:
+ # x, _ = block(x)
+ # for streaming inference
+ x = self.rwkv_infer(x)
x = self.final_norm(x)
if self.time_reduction_factor > 1:
@@ -136,9 +137,9 @@
state = [
torch.zeros(
- (batch_size, 1, hidden_sizes[i], self.num_rwkv_blocks),
+ (batch_size, 1, hidden_sizes[i], self.num_blocks),
dtype=torch.float32,
- device=self.device,
+ device=xs_pad.device,
)
for i in range(5)
]
@@ -151,5 +152,5 @@
for idx, block in enumerate(self.rwkv_blocks):
x_t, state = block(x_t, state=state)
xs_out.append(x_t)
- xs_out = torch.stack(xs_out, dim=1)
+ xs_out = torch.cat(xs_out, dim=1)
return xs_out
--
Gitblit v1.9.1