游雁
2023-03-16 74464315c168aae0d7b9c494d3351dc2594996ae
funasr/modules/embedding.py
@@ -405,4 +405,13 @@
        positions = torch.arange(1, timesteps+1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding
        return x + position_encoding
    def forward_chunk(self, x, cache=None):
        start_idx = 0
        batch_size, timesteps, input_dim = x.size()
        if cache is not None:
            start_idx = cache["start_idx"]
        positions = torch.arange(1, timesteps+start_idx+1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding[:, start_idx: start_idx + timesteps]