From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/eend/encoder.py |   53 ++++++++++++++++++++++++++---------------------------
 1 file changed, 26 insertions(+), 27 deletions(-)

diff --git a/funasr/models/eend/encoder.py b/funasr/models/eend/encoder.py
index 3065884..0dbd98f 100644
--- a/funasr/models/eend/encoder.py
+++ b/funasr/models/eend/encoder.py
@@ -7,7 +7,7 @@
 
 class MultiHeadSelfAttention(nn.Module):
     def __init__(self, n_units, h=8, dropout_rate=0.1):
-        super(MultiHeadSelfAttention, self).__init__()
+        super().__init__()
         self.linearQ = nn.Linear(n_units, n_units)
         self.linearK = nn.Linear(n_units, n_units)
         self.linearV = nn.Linear(n_units, n_units)
@@ -20,8 +20,7 @@
         q = self.linearQ(x).view(batch_size, -1, self.h, self.d_k)
         k = self.linearK(x).view(batch_size, -1, self.h, self.d_k)
         v = self.linearV(x).view(batch_size, -1, self.h, self.d_k)
-        scores = torch.matmul(
-            q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1)) / math.sqrt(self.d_k)
+        scores = torch.matmul(q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1)) / math.sqrt(self.d_k)
         if x_mask is not None:
             x_mask = x_mask.unsqueeze(1)
             scores = scores.masked_fill(x_mask == 0, -1e9)
@@ -61,9 +60,7 @@
                 return
         pe = torch.zeros(x.size(1), self.d_model)
         if self.reverse:
-            position = torch.arange(
-                x.size(1) - 1, -1, -1.0, dtype=torch.float32
-            ).unsqueeze(1)
+            position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
         else:
             position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
         div_term = torch.exp(
@@ -82,38 +79,40 @@
 
 
 class EENDOLATransformerEncoder(nn.Module):
-    def __init__(self,
-                 idim: int,
-                 n_layers: int,
-                 n_units: int,
-                 e_units: int = 2048,
-                 h: int = 4,
-                 dropout_rate: float = 0.1,
-                 use_pos_emb: bool = False):
+    def __init__(
+        self,
+        idim: int,
+        n_layers: int,
+        n_units: int,
+        e_units: int = 2048,
+        h: int = 4,
+        dropout_rate: float = 0.1,
+        use_pos_emb: bool = False,
+    ):
         super(EENDOLATransformerEncoder, self).__init__()
         self.linear_in = nn.Linear(idim, n_units)
         self.lnorm_in = nn.LayerNorm(n_units)
         self.n_layers = n_layers
         self.dropout = nn.Dropout(dropout_rate)
         for i in range(n_layers):
-            setattr(self, '{}{:d}'.format("lnorm1_", i),
-                    nn.LayerNorm(n_units))
-            setattr(self, '{}{:d}'.format("self_att_", i),
-                    MultiHeadSelfAttention(n_units, h))
-            setattr(self, '{}{:d}'.format("lnorm2_", i),
-                    nn.LayerNorm(n_units))
-            setattr(self, '{}{:d}'.format("ff_", i),
-                    PositionwiseFeedForward(n_units, e_units, dropout_rate))
+            setattr(self, "{}{:d}".format("lnorm1_", i), nn.LayerNorm(n_units))
+            setattr(self, "{}{:d}".format("self_att_", i), MultiHeadSelfAttention(n_units, h))
+            setattr(self, "{}{:d}".format("lnorm2_", i), nn.LayerNorm(n_units))
+            setattr(
+                self,
+                "{}{:d}".format("ff_", i),
+                PositionwiseFeedForward(n_units, e_units, dropout_rate),
+            )
         self.lnorm_out = nn.LayerNorm(n_units)
 
     def __call__(self, x, x_mask=None):
         BT_size = x.shape[0] * x.shape[1]
         e = self.linear_in(x.reshape(BT_size, -1))
         for i in range(self.n_layers):
-            e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
-            s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
+            e = getattr(self, "{}{:d}".format("lnorm1_", i))(e)
+            s = getattr(self, "{}{:d}".format("self_att_", i))(e, x.shape[0], x_mask)
             e = e + self.dropout(s)
-            e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
-            s = getattr(self, '{}{:d}'.format("ff_", i))(e)
+            e = getattr(self, "{}{:d}".format("lnorm2_", i))(e)
+            s = getattr(self, "{}{:d}".format("ff_", i))(e)
             e = e + self.dropout(s)
-        return self.lnorm_out(e)
\ No newline at end of file
+        return self.lnorm_out(e)

--
Gitblit v1.9.1