游雁
2024-01-16 bbbf17e4d97ff155049c424af4e96bfded9089b1
funasr/models/rwkv_bat/rwkv_encoder.py
@@ -113,12 +113,12 @@
        x = self.embed_norm(x)
        olens = mask.eq(0).sum(1)
        # for training
        # for block in self.rwkv_blocks:
        #     x, _ = block(x)
        # for streaming inference
        x = self.rwkv_infer(x)
        if self.training:
            for block in self.rwkv_blocks:
                x, _ = block(x)
        else:
            x = self.rwkv_infer(x)
        x = self.final_norm(x)
        if self.time_reduction_factor > 1: