From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/transformer/embedding.py |   88 +++++++++++++++++++++++---------------------
 1 file changed, 46 insertions(+), 42 deletions(-)

diff --git a/funasr/models/transformer/embedding.py b/funasr/models/transformer/embedding.py
index 1995bbe..7c86dab 100644
--- a/funasr/models/transformer/embedding.py
+++ b/funasr/models/transformer/embedding.py
@@ -11,6 +11,7 @@
 import torch.nn.functional as F
 from torch import einsum
 
+
 def _pre_hook(
     state_dict,
     prefix,
@@ -64,9 +65,7 @@
                 return
         pe = torch.zeros(x.size(1), self.d_model)
         if self.reverse:
-            position = torch.arange(
-                x.size(1) - 1, -1, -1.0, dtype=torch.float32
-            ).unsqueeze(1)
+            position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
         else:
             position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
@@ -170,9 +169,7 @@
         if self.gamma is None:
             self.gamma = self.d_model // 2
 
-        assert (
-            d_model % 2 == 0
-        ), "d_model should be divisible by two in order to use this layer."
+        assert d_model % 2 == 0, "d_model should be divisible by two in order to use this layer."
         self.w_r = torch.nn.Parameter(torch.empty(1, d_model // 2))
         self._reset()  # init the weights
 
@@ -185,9 +182,7 @@
             )
 
     def _reset(self):
-        self.w_r.data = torch.normal(
-            0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2)
-        )
+        self.w_r.data = torch.normal(0, (1 / math.sqrt(self.gamma)), (1, self.d_model // 2))
 
     def extend_pe(self, x):
         """Reset the positional encodings."""
@@ -384,45 +379,57 @@
         x = x * self.xscale + self.pe[:, start_idx : start_idx + x.size(1)]
         return self.dropout(x)
 
-class SinusoidalPositionEncoder(torch.nn.Module):
-    '''
 
-    '''
+class SinusoidalPositionEncoder(torch.nn.Module):
+    """ """
+
     def __int__(self, d_model=80, dropout_rate=0.1):
         pass
 
-    def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
+    def encode(
+        self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
+    ):
         batch_size = positions.size(0)
         positions = positions.type(dtype)
         device = positions.device
-        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (depth / 2 - 1)
-        inv_timescales = torch.exp(torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment))
+        log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
+            depth / 2 - 1
+        )
+        inv_timescales = torch.exp(
+            torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
+        )
         inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
-        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
+        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
+            inv_timescales, [1, 1, -1]
+        )
         encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
         return encoding.type(dtype)
 
     def forward(self, x):
         batch_size, timesteps, input_dim = x.size()
-        positions = torch.arange(1, timesteps+1, device=x.device)[None, :]
+        positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
         position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
 
         return x + position_encoding
 
-class StreamSinusoidalPositionEncoder(torch.nn.Module):
-    '''
 
-    '''
+class StreamSinusoidalPositionEncoder(torch.nn.Module):
+    """ """
+
     def __int__(self, d_model=80, dropout_rate=0.1):
         pass
 
-    def encode(self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32):
+    def encode(
+        self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
+    ):
         batch_size = positions.size(0)
         positions = positions.type(dtype)
         log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype)) / (depth / 2 - 1)
         inv_timescales = torch.exp(torch.arange(depth / 2).type(dtype) * (-log_timescale_increment))
         inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
-        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(inv_timescales, [1, 1, -1])
+        scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
+            inv_timescales, [1, 1, -1]
+        )
         encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
         return encoding.type(dtype)
 
@@ -432,9 +439,10 @@
         if cache is not None:
             start_idx = cache["start_idx"]
             cache["start_idx"] += timesteps
-        positions = torch.arange(1, timesteps+start_idx+1)[None, :]
+        positions = torch.arange(1, timesteps + start_idx + 1)[None, :]
         position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
-        return x + position_encoding[:, start_idx: start_idx + timesteps]
+        return x + position_encoding[:, start_idx : start_idx + timesteps]
+
 
 class StreamingRelPositionalEncoding(torch.nn.Module):
     """Relative positional encoding.
@@ -444,9 +452,7 @@
         dropout_rate: Dropout rate.
     """
 
-    def __init__(
-        self, size: int, dropout_rate: float = 0.0, max_len: int = 5000
-    ) -> None:
+    def __init__(self, size: int, dropout_rate: float = 0.0, max_len: int = 5000) -> None:
         """Construct a RelativePositionalEncoding object."""
         super().__init__()
 
@@ -477,8 +483,7 @@
 
         position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
-            torch.arange(0, self.size, 2, dtype=torch.float32)
-            * -(math.log(10000.0) / self.size)
+            torch.arange(0, self.size, 2, dtype=torch.float32) * -(math.log(10000.0) / self.size)
         )
 
         pe_positive[:, 0::2] = torch.sin(position * div_term)
@@ -489,9 +494,7 @@
         pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
         pe_negative = pe_negative[1:].unsqueeze(0)
 
-        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(
-            dtype=x.dtype, device=x.device
-        )
+        self.pe = torch.cat([pe_positive, pe_negative], dim=1).to(dtype=x.dtype, device=x.device)
 
     def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
         """Compute positional encoding.
@@ -505,9 +508,7 @@
 
         time1 = x.size(1) + left_context
 
-        pos_enc = self.pe[
-            :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)
-        ]
+        pos_enc = self.pe[:, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1)]
         pos_enc = self.dropout(pos_enc)
 
         return pos_enc
@@ -516,14 +517,17 @@
 class ScaledSinuEmbedding(torch.nn.Module):
     def __init__(self, dim):
         super().__init__()
-        self.scale = torch.nn.Parameter(torch.ones(1,))
-        inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
-        self.register_buffer('inv_freq', inv_freq)
+        self.scale = torch.nn.Parameter(
+            torch.ones(
+                1,
+            )
+        )
+        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+        self.register_buffer("inv_freq", inv_freq)
 
     def forward(self, x):
         n, device = x.shape[1], x.device
-        t = torch.arange(n, device = device).type_as(self.inv_freq)
-        sinu = einsum('i , j -> i j', t, self.inv_freq)
-        emb = torch.cat((sinu.sin(), sinu.cos()), dim = -1)
+        t = torch.arange(n, device=device).type_as(self.inv_freq)
+        sinu = einsum("i , j -> i j", t, self.inv_freq)
+        emb = torch.cat((sinu.sin(), sinu.cos()), dim=-1)
         return emb * self.scale
-

--
Gitblit v1.9.1