shixian.shi
2024-01-15 55c09aeaa25b4bb88a50e09ba68fa6ff00a6d676
funasr/models/rwkv_bat/rwkv_encoder.py
@@ -113,12 +113,12 @@
        x = self.embed_norm(x)
        olens = mask.eq(0).sum(1)
        # for training
        # for block in self.rwkv_blocks:
        #     x, _ = block(x)
        # for streaming inference
        x = self.rwkv_infer(x)
        if self.training:
            for block in self.rwkv_blocks:
                x, _ = block(x)
        else:
            x = self.rwkv_infer(x)
        x = self.final_norm(x)
        if self.time_reduction_factor > 1: