From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat funasr/models/transformer/embedding.py (line wrapping and quote style)
---
funasr/models/transformer/embedding.py | 88 +++++++++++++++++++++++---------------------
 1 file changed, 46 insertions(+), 42 deletions(-)
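Reviewer note (placed below the "---" cut line, so it is not part of the commit
message and is ignored by git am): every hunk in this patch is a formatting-only
change (line wrapping, quote style, operator spacing); no behavioral change is
intended. As a quick sanity check, here is a minimal, hypothetical usage sketch
of the reformatted SinusoidalPositionEncoder. It assumes a FunASR checkout on
PYTHONPATH; note that the pre-existing "def __int__" typo (unchanged by this
patch) means the default torch.nn.Module constructor runs, so the
d_model/dropout_rate arguments are never used.

    import torch
    from funasr.models.transformer.embedding import SinusoidalPositionEncoder

    # The __int__ typo means no custom init runs; the module is stateless anyway.
    enc = SinusoidalPositionEncoder()
    x = torch.zeros(2, 100, 80)   # (batch, time, feature_dim); dim must be even
    y = enc(x)                    # adds sin/cos positional encoding to x
    assert y.shape == x.shape

StreamSinusoidalPositionEncoder behaves the same way, except that forward
offsets positions by cache["start_idx"] so that successive streaming chunks
receive consecutive positional encodings.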
diff --git a/funasr/models/transformer/embedding.py b/funasr/models/transformer/embedding.py
index 1995bbe..7c86dab 100644
--- a/funasr/models/transformer/embedding.py
+++ b/funasr/models/transformer/embedding.py
@@ -11,6 +11,7 @@
import torch.nn.functional as F
from torch import einsum
+
def _pre_hook(
state_dict,
prefix,
@@ -64,9 +65,7 @@
return
pe = torch.zeros(x.size(1), self.d_model)
if self.reverse:
- position = torch.arange(
- x.size(1) - 1, -1, -1.0, dtype=torch.float32
- ).unsqueeze(1)
+ position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
@@ -170,9 +169,7 @@
if self.gamma is None:
self.gamma = self.d_model // 2
- assert (
- d_model % 2 == 0
- ), "d_model should be divisible by two in order to use this layer."
+ assert d_model % 2 == 0, "d_model should be divisible by two in order to use this layer."
self.w_r = torch.nn.Parameter(torch.empty(1, d_model // 2))
self._reset() # init the weights
@@ -185,9 +182,7 @@
)
def _reset(self):
- self.w_r.data = torch.normal(
- 0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2)
- )
+ self.w_r.data = torch.normal(0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2))
def extend_pe(self, x):
"""Reset the positional encodings."""
@@ -384,45 +379,57 @@
x = x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
return self.dropout(x)
-class SinusoidalPositionEncoder(torch.nn.Module):
- '''
- '''
+class SinusoidalPositionEncoder(torch.nn.Module):
+ """ """
+
def __int__(self, d_model=80, dropout_rate=0.1):
pass
- def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
+ def encode(
+ self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
+ ):
batch_size = positions.size(0)
positions = positions.type(dtype)
device = positions.device
- log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (depth / 2 - 1)
- inv_timescales = torch.exp(torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment))
+ log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
+ depth / 2 - 1
+ )
+ inv_timescales = torch.exp(
+ torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
+ )
inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
- scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
+ scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
+ inv_timescales, [1, 1, -1]
+ )
encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
return encoding.type(dtype)
def forward(self, x):
batch_size, timesteps, input_dim = x.size()
- positions = torch.arange(1, timesteps+1, device=x.device)[None, :]
+ positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
return x + position_encoding
-class StreamSinusoidalPositionEncoder(torch.nn.Module):
- '''
- '''
+class StreamSinusoidalPositionEncoder(torch.nn.Module):
+ """ """
+
def __int__(self, d_model=80, dropout_rate=0.1):
pass
- def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
+ def encode(
+ self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
+ ):
batch_size = positions.size(0)
positions = positions.type(dtype)
log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1)
inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment))
inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
- scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
+ scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
+ inv_timescales, [1, 1, -1]
+ )
encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
return encoding.type(dtype)
@@ -432,9 +439,10 @@
if cache is not None:
start_idx = cache["start_idx"]
cache["start_idx"] += timesteps
- positions = torch.arange(1, timesteps+start_idx+1)[None, :]
+ positions = torch.arange(1, timesteps + start_idx + 1)[None, :]
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
- return x + position_encoding[:, start_idx: start_idx + timesteps]
+ return x + position_encoding[:, start_idx : start_idx + timesteps]
+
class StreamingRelPositionalEncoding(torch.nn.Module):
"""Relative positional encoding.
@@ -444,9 +452,7 @@
dropout_rate: Dropout rate.
"""
- def __init__(
- self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
- ) -> None:
+ def __init__(self, size: int, dropout_rate: float = 0.0, max_len: int = 5000) -> None:
"""Construct a RelativePositionalEncoding object."""
super().__init__()
@@ -477,8 +483,7 @@
position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
- torch.arange(0, self.size, 2, dtype=torch.float32)
- * -(math.log(10000.0) / self.size)
+ torch.arange(0, self.size, 2, dtype=torch.float32) * -(math.log(10000.0) / self.size)
)
pe_positive[:, 0::2] = torch.sin(position * div_term)
@@ -489,9 +494,7 @@
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
pe_negative = pe_negative[1:].unsqueeze(0)
- self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
- dtype=x.dtype, device=x.device
- )
+ self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(dtype=x.dtype, device=x.device)
def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
"""Compute positional encoding.
@@ -505,9 +508,7 @@
time1 = x.size(1) + left_context
- pos_enc = self.pe[
- :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
- ]
+ pos_enc = self.pe[:, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)]
pos_enc = self.dropout(pos_enc)
return pos_enc
@@ -516,14 +517,17 @@
class ScaledSinuEmbedding(torch.nn.Module):
def __init__(self, dim):
super().__init__()
- self.scale = torch.nn.Parameter(torch.ones(1,))
- inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
- self.register_buffer('inv_freq', inv_freq)
+ self.scale = torch.nn.Parameter(
+ torch.ones(
+ 1,
+ )
+ )
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
def forward(self, x):
n, device = x.shape[1], x.device
- t = torch.arange(n, device = device).type_as(self.inv_freq)
- sinu = einsum('i , j -> i j', t, self.inv_freq)
- emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
+ t = torch.arange(n, device=device).type_as(self.inv_freq)
+ sinu = einsum("i , j -> i j", t, self.inv_freq)
+ emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
return emb * self.scale
-
--
Gitblit v1.9.1