游雁
2024-01-14 8912e0696af069de47646fdb8a9d9c4e086e88b3
funasr/models/rwkv_bat/rwkv_encoder.py
@@ -113,12 +113,12 @@
        x = self.embed_norm(x)
        olens = mask.eq(0).sum(1)
        # for training
        # for block in self.rwkv_blocks:
        #     x, _ = block(x)
        # for streaming inference
        x = self.rwkv_infer(x)
        if self.training:
            for block in self.rwkv_blocks:
                x, _ = block(x)
        else:
            x = self.rwkv_infer(x)
        x = self.final_norm(x)
        if self.time_reduction_factor > 1: