haoneng.lhn
2023-03-29 d0d8684b964f06ab81279fa11a3725aaff01161c
funasr/modules/embedding.py
@@ -8,7 +8,7 @@
import math
import torch
import torch.nn.functional as F
def _pre_hook(
    state_dict,
@@ -409,9 +409,18 @@
    def forward_chunk(self, x, cache=None):
        start_idx = 0
        pad_left = 0
        pad_right = 0
        batch_size, timesteps, input_dim = x.size()
        if cache is not None:
            start_idx = cache["start_idx"]
            pad_left = cache["left"]
            pad_right = cache["right"]
        positions = torch.arange(1, timesteps+start_idx+1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        return x + position_encoding[:, start_idx: start_idx + timesteps]
        outputs = x + position_encoding[:, start_idx: start_idx + timesteps]
        outputs = outputs.transpose(1,2)
        outputs = F.pad(outputs, (pad_left, pad_right))
        outputs = outputs.transpose(1,2)
        return outputs