haoneng.lhn
2023-04-27 7584bbd6f3e321cc8bc970739a7cfce29ffcc18b
funasr/modules/embedding.py
@@ -425,21 +425,14 @@
        return encoding.type(dtype)
    def forward(self, x, cache=None):
        start_idx = 0
        pad_left = 0
        pad_right = 0
        batch_size, timesteps, input_dim = x.size()
        start_idx = 0
        if cache is not None:
            start_idx = cache["start_idx"]
            pad_left = cache["left"]
            pad_right = cache["right"]
            cache["start_idx"] += timesteps
        positions = torch.arange(1, timesteps+start_idx+1)[None, :]
        position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
        outputs = x + position_encoding[:, start_idx: start_idx + timesteps]
        outputs = outputs.transpose(1, 2)
        outputs = F.pad(outputs, (pad_left, pad_right))
        outputs = outputs.transpose(1, 2)
        return outputs
        return x + position_encoding[:, start_idx: start_idx + timesteps]
class StreamingRelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding.