From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages

---
 funasr/modules/embedding.py |   24 +++++++++++++++++++++---
 1 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/funasr/modules/embedding.py b/funasr/modules/embedding.py
index aaac80a..1995bbe 100644
--- a/funasr/modules/embedding.py
+++ b/funasr/modules/embedding.py
@@ -9,6 +9,7 @@
 import math
 import torch
 import torch.nn.functional as F
+from torch import einsum
 
 def _pre_hook(
     state_dict,
@@ -393,8 +394,9 @@
     def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
         batch_size = positions.size(0)
         positions = positions.type(dtype)
-        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1)
-        inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment))
+        device = positions.device
+        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (depth / 2 - 1)
+        inv_timescales = torch.exp(torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment))
         inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
         scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
         encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
@@ -402,7 +404,7 @@
 
     def forward(self, x):
         batch_size, timesteps, input_dim = x.size()
-        positions = torch.arange(1, timesteps+1)[None, :]
+        positions = torch.arange(1, timesteps+1, device=x.device)[None, :]
         position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
 
         return x + position_encoding
@@ -509,3 +511,19 @@
         pos_enc = self.dropout(pos_enc)
 
         return pos_enc
+
+
+class ScaledSinuEmbedding(torch.nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.scale = torch.nn.Parameter(torch.ones(1,))
+        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer('inv_freq', inv_freq)
+
+    def forward(self, x):
+        n, device = x.shape[1], x.device
+        t = torch.arange(n, device = device).type_as(self.inv_freq)
+        sinu = einsum('i , j -> i j', t, self.inv_freq)
+        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
+        return emb * self.scale
+

--
Gitblit v1.9.1