funasr/modules/embedding.py
@@ -406,3 +406,12 @@
         position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
 
         return x + position_encoding
+
+    def forward_chunk(self, x, cache=None):
+        start_idx = 0
+        batch_size, timesteps, input_dim = x.size()
+        if cache is not None:
+            start_idx = cache["start_idx"]
+        positions = torch.arange(1, timesteps+start_idx+1)[None, :]
+        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
+        return x + position_encoding[:, start_idx: start_idx + timesteps]
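
The added forward_chunk makes the sinusoidal position encoding usable for chunk-wise streaming: the cache carries the absolute start index of the incoming chunk, the encoding is built for the whole span seen so far, and only the slice covering the current chunk is added to x, so positions stay globally consistent across chunks. A minimal usage sketch follows; it assumes the method lives on FunASR's SinusoidalPositionEncoder and that the caller advances cache["start_idx"] between chunks, neither of which is shown in this hunk:

    import torch
    from funasr.modules.embedding import SinusoidalPositionEncoder

    pos_enc = SinusoidalPositionEncoder()
    cache = {"start_idx": 0}             # absolute offset of the next chunk (assumed caller-managed)

    x = torch.randn(1, 48, 256)          # (batch, time, dim)
    outs = []
    for chunk in x.split(16, dim=1):     # feed three 16-frame chunks
        outs.append(pos_enc.forward_chunk(chunk, cache=cache))
        cache["start_idx"] += chunk.size(1)   # advance the offset past this chunk

    # the streamed result should match the non-streaming forward pass
    assert torch.allclose(torch.cat(outs, dim=1), pos_enc(x))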