Yabin Li
2023-04-06 0eacba96a12d5c0dea89c4533ca68b40decd8e9f
funasr/modules/embedding.py
@@ -8,7 +8,7 @@
import math
import torch
import torch.nn.functional as F
def _pre_hook(
    state_dict,
@@ -409,9 +409,18 @@
    def forward_chunk(self, x, cache=None):
        start_idx = 0
        pad_left = 0
        pad_right = 0
        batch_size, timesteps, input_dim = x.size()
        if cache is not None:
            start_idx = cache["start_idx"]
            pad_left = cache["left"]
            pad_right = cache["right"]
        positions = torch.arange(1, timesteps+start_idx+1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding[:, start_idx: start_idx + timesteps]
        outputs = x + position_encoding[:, start_idx: start_idx + timesteps]
        outputs = outputs.transpose(1,2)
        outputs = F.pad(outputs, (pad_left, pad_right))
        outputs = outputs.transpose(1,2)
        return outputs