aky15
2023-11-01 4e0404e04ed890717ead276e52c927a820326ec1
funasr/models/encoder/rwkv_encoder.py
@@ -113,11 +113,12 @@
        x = self.embed_norm(x)
        olens = mask.eq(0).sum(1)
        for block in self.rwkv_blocks:
            x, _ = block(x)
        # for streaming inference
        # xs_pad = self.rwkv_infer(xs_pad)
        # for training
        # for block in self.rwkv_blocks:
        #     x, _ = block(x)
        # for streaming inference
        x = self.rwkv_infer(x)
        x = self.final_norm(x)
        if self.time_reduction_factor > 1:
@@ -136,9 +137,9 @@
        state = [
            torch.zeros(
                (batch_size, 1, hidden_sizes[i], self.num_rwkv_blocks),
                (batch_size, 1, hidden_sizes[i], self.num_blocks),
                dtype=torch.float32,
                device=self.device,
                device=xs_pad.device,
            )
            for i in range(5)
        ]
@@ -151,5 +152,5 @@
            for idx, block in enumerate(self.rwkv_blocks):
                x_t, state = block(x_t, state=state)
            xs_out.append(x_t)
        xs_out = torch.stack(xs_out, dim=1)
        xs_out = torch.cat(xs_out, dim=1)
        return xs_out