仁迷
2023-03-13 3762d21300e1f3fa3e0cb1e67545227e6dcec3de
funasr/modules/embedding.py
@@ -405,4 +405,13 @@
        positions = torch.arange(1, timesteps+1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding

    def forward_chunk(self, x, cache=None):
        # Streaming variant: shift the positions by the number of frames already
        # seen, so consecutive chunks continue a single positional encoding.
        start_idx = 0
        batch_size, timesteps, input_dim = x.size()
        if cache is not None:
            start_idx = cache["start_idx"]
        # Encode positions 1 .. start_idx + timesteps, then keep only the slice
        # that belongs to the current chunk.
        positions = torch.arange(1, timesteps + start_idx + 1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding[:, start_idx: start_idx + timesteps]
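
For context, a minimal self-contained sketch of how this chunked encoder can be driven. This is not the FunASR class itself: the class name SinusoidalPE, the tensor2tensor-style encode formula, and the in-method cache update (cache["start_idx"] += timesteps; in the diff above the cache is only read, so the caller is presumably responsible for advancing it) are all assumptions for illustration. The final check confirms that two chunks fed through forward_chunk reproduce the full-sequence encoding.

import math

import torch


class SinusoidalPE(torch.nn.Module):
    """Hypothetical stand-in for the encoder patched above."""

    def encode(self, positions, depth, dtype):
        # positions: (1, T) of 1-based indices -> (1, T, depth) sinusoidal table.
        positions = positions.type(dtype)
        log_timescale_increment = math.log(10000.0) / (depth / 2 - 1)
        inv_timescales = torch.exp(
            torch.arange(depth / 2, dtype=dtype) * -log_timescale_increment
        )
        scaled_time = positions[:, :, None] * inv_timescales[None, None, :]
        return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)

    def forward(self, x):
        batch_size, timesteps, input_dim = x.size()
        positions = torch.arange(1, timesteps + 1)[None, :]
        return x + self.encode(positions, input_dim, x.dtype).to(x.device)

    def forward_chunk(self, x, cache=None):
        start_idx = 0
        batch_size, timesteps, input_dim = x.size()
        if cache is not None:
            start_idx = cache["start_idx"]
            cache["start_idx"] += timesteps  # assumed caller contract, folded in here
        positions = torch.arange(1, timesteps + start_idx + 1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding[:, start_idx: start_idx + timesteps]


pe = SinusoidalPE()
x = torch.randn(1, 10, 8)
full = pe(x)

# Feed the same sequence as a 6-frame chunk followed by a 4-frame chunk.
cache = {"start_idx": 0}
chunked = torch.cat(
    [pe.forward_chunk(x[:, :6], cache), pe.forward_chunk(x[:, 6:], cache)], dim=1
)
assert torch.allclose(full, chunked)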