北念
2023-03-22 93d78edee3be55f71a2ab22cf79b881a21df8869
funasr/modules/embedding.py
@@ -405,4 +405,13 @@
        positions = torch.arange(1, timesteps+1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding
        return x + position_encoding
    def forward_chunk(self, x, cache=None):
        start_idx = 0
        batch_size, timesteps, input_dim = x.size()
        if cache is not None:
            start_idx = cache["start_idx"]
        positions = torch.arange(1, timesteps+start_idx+1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding[:, start_idx: start_idx + timesteps]