hnluo
2023-08-10 ea2c102e6162c924c682aabfe8a052ce9a766a4d
funasr/modules/embedding.py
@@ -9,6 +9,7 @@
import math
import torch
import torch.nn.functional as F
from torch import einsum
def _pre_hook(
    state_dict,
@@ -510,3 +511,19 @@
        pos_enc = self.dropout(pos_enc)
        return pos_enc
class ScaledSinuEmbedding(torch.nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1,))
        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
    def forward(self, x):
        n, device = x.shape[1], x.device
        t = torch.arange(n, device = device).type_as(self.inv_freq)
        sinu = einsum('i , j -> i j', t, self.inv_freq)
        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
        return emb * self.scale