From ea2c102e6162c924c682aabfe8a052ce9a766a4d Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期四, 10 八月 2023 20:17:53 +0800
Subject: [PATCH] Merge pull request #832 from alibaba-damo-academy/dev_lhn

---
 funasr/modules/embedding.py |   17 +++++++++++++++++
 1 files changed, 17 insertions(+), 0 deletions(-)

diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py
index 374eba4..1995bbe 100644
--- a/funasr/modules/embedding.py
+++ b/funasr/modules/embedding.py
@@ -9,6 +9,7 @@
 import math
 import torch
 import torch.nn.functional as F
+from torch import einsum
 
 def _pre_hook(
     state_dict,
@@ -510,3 +511,19 @@
         pos_enc = self.dropout(pos_enc)
 
         return pos_enc
+
+
+class ScaledSinuEmbedding(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = torch.nn.Parameter(torch.ones(1,))
+        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, x):
+        n, device = x.shape[1], x.device
+        t = torch.arange(n, device = device).type_as(self.inv_freq)
+        sinu = einsum('i , j -> i j', t, self.inv_freq)
+        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
+        return emb * self.scale
+

--
Gitblit v1.9.1