From ea2c102e6162c924c682aabfe8a052ce9a766a4d Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: Thu, 10 Aug 2023 20:17:53 +0800
Subject: [PATCH] Merge pull request #832 from alibaba-damo-academy/dev_lhn
---
 funasr/modules/embedding.py | 24 ++++++++++++++++++++++++
 1 file changed, 24 insertions(+)
diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py
index 374eba4..1995bbe 100644
--- a/funasr/modules/embedding.py
+++ b/funasr/modules/embedding.py
@@ -9,6 +9,7 @@
 import math
 import torch
 import torch.nn.functional as F
+from torch import einsum
 
 def _pre_hook(
     state_dict,
@@ -510,3 +511,26 @@
         pos_enc = self.dropout(pos_enc)
         return pos_enc
+
+
+class ScaledSinuEmbedding(torch.nn.Module):
+    """Sinusoidal positional embedding with a learnable scale."""
+
+    def __init__(self, dim):
+        super().__init__()
+        # Learnable scalar applied to the fixed sinusoidal table.
+        self.scale = torch.nn.Parameter(torch.ones(1,))
+        # Standard sinusoidal inverse frequencies: 1 / 10000^(2i / dim).
+        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, x):
+        # x: (batch, seq_len, ...); embed each of the seq_len positions.
+        n, device = x.shape[1], x.device
+        t = torch.arange(n, device=device).type_as(self.inv_freq)
+        # Outer product of positions and inverse frequencies: (n, dim // 2).
+        sinu = einsum('i , j -> i j', t, self.inv_freq)
+        # Concatenate sin and cos halves into an (n, dim) embedding.
+        emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
+        return emb * self.scale
+
--
Gitblit v1.9.1
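
For reference, a minimal usage sketch of the ScaledSinuEmbedding added above; the tensor sizes and the broadcast-add onto the input features are illustrative assumptions, not part of the patch:

    import torch
    from funasr.modules.embedding import ScaledSinuEmbedding

    batch, seq_len, dim = 2, 100, 256   # illustrative sizes only; dim must be even

    pos_emb = ScaledSinuEmbedding(dim)
    x = torch.randn(batch, seq_len, dim)

    pos = pos_emb(x)                    # (seq_len, dim) scaled sin/cos table
    assert pos.shape == (seq_len, dim)

    # Typical use: broadcast-add the positions onto the features.
    x = x + pos.unsqueeze(0)

Unlike the fixed sinusoidal encodings elsewhere in this file, the learnable scale lets the model adjust the magnitude of the positional signal relative to the features it is added to.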