shixian.shi
2024-01-16 b7cb19b01a1454f7a1388e24dcd4e10fc654bd7c
funasr/models/rwkv_bat/rwkv_encoder.py
@@ -113,12 +113,12 @@
        x = self.embed_norm(x)
        olens = mask.eq(0).sum(1)
        # for training
        # for block in self.rwkv_blocks:
        #     x, _ = block(x)
        # for streaming inference
        x = self.rwkv_infer(x)
        if self.training:
            for block in self.rwkv_blocks:
                x, _ = block(x)
        else:
            x = self.rwkv_infer(x)
        x = self.final_norm(x)
        if self.time_reduction_factor > 1: