夜雨飘零
2024-02-03 2cf4084b23db9bd9e8ce4db76d0628ef6655ed71
funasr/models/rwkv_bat/rwkv_encoder.py
@@ -113,12 +113,12 @@
        x = self.embed_norm(x)
        olens = mask.eq(0).sum(1)
        # for training
        # for block in self.rwkv_blocks:
        #     x, _ = block(x)
        # for streaming inference
        x = self.rwkv_infer(x)
        if self.training:
            for block in self.rwkv_blocks:
                x, _ = block(x)
        else:
            x = self.rwkv_infer(x)
        x = self.final_norm(x)
        if self.time_reduction_factor > 1: