zhifu gao
2023-03-17 03d4ce829814b4a7f57235fda049351c524ba32b
funasr/modules/embedding.py
@@ -405,4 +405,13 @@
        positions = torch.arange(1, timesteps+1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding
        return x + position_encoding
    def forward_chunk(self, x, cache=None):
        start_idx = 0
        batch_size, timesteps, input_dim = x.size()
        if cache is not None:
            start_idx = cache["start_idx"]
        positions = torch.arange(1, timesteps+start_idx+1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding[:, start_idx: start_idx + timesteps]