funasr/models/rwkv_bat/rwkv_encoder.py
@@ -113,12 +113,12 @@ x = self.embed_norm(x) olens = mask.eq(0).sum(1) # for training # for block in self.rwkv_blocks: # x, _ = block(x) # for streaming inference x = self.rwkv_infer(x) if self.training: for block in self.rwkv_blocks: x, _ = block(x) else: x = self.rwkv_infer(x) x = self.final_norm(x) if self.time_reduction_factor > 1: