From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat funasr/models/eend/encoder.py (style-only cleanup)
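
Style-only reformat of the EEND encoder: collapse wrapped expressions that fit
on one line, use double quotes consistently, simplify the super() call in
MultiHeadSelfAttention, reflow the EENDOLATransformerEncoder constructor and
setattr calls, and add the missing newline at end of file. No functional change.

For reviewers, a minimal usage sketch of the reformatted encoder. The sizes
below are illustrative assumptions, not values from this patch, and the output
shape is inferred from the flattening done in __call__:

    import torch
    from funasr.models.eend.encoder import EENDOLATransformerEncoder

    # Illustrative sizes only (assumptions, not part of this patch).
    B, T, idim = 2, 100, 345
    enc = EENDOLATransformerEncoder(idim=idim, n_layers=4, n_units=256)

    x = torch.randn(B, T, idim)   # (batch, frames, input features)
    out = enc(x)                  # __call__ flattens the input to (B * T, n_units)
    print(out.shape)              # expected: torch.Size([200, 256])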
---
funasr/models/eend/encoder.py | 53 ++++++++++++++++++++++++++---------------------------
1 file changed, 26 insertions(+), 27 deletions(-)
diff --git a/funasr/models/eend/encoder.py b/funasr/models/eend/encoder.py
index 3065884..0dbd98f 100644
--- a/funasr/models/eend/encoder.py
+++ b/funasr/models/eend/encoder.py
@@ -7,7 +7,7 @@
class MultiHeadSelfAttention(nn.Module):
def __init__(self, n_units, h=8, dropout_rate=0.1):
- super(MultiHeadSelfAttention, self).__init__()
+ super().__init__()
self.linearQ = nn.Linear(n_units, n_units)
self.linearK = nn.Linear(n_units, n_units)
self.linearV = nn.Linear(n_units, n_units)
@@ -20,8 +20,7 @@
q = self.linearQ(x).view(batch_size, -1, self.h, self.d_k)
k = self.linearK(x).view(batch_size, -1, self.h, self.d_k)
v = self.linearV(x).view(batch_size, -1, self.h, self.d_k)
- scores = torch.matmul(
- q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1)) / math.sqrt(self.d_k)
+ scores = torch.matmul(q.permute(0, 2, 1, 3), k.permute(0, 2, 3, 1)) / math.sqrt(self.d_k)
if x_mask is not None:
x_mask = x_mask.unsqueeze(1)
scores = scores.masked_fill(x_mask == 0, -1e9)
@@ -61,9 +60,7 @@
return
pe = torch.zeros(x.size(1), self.d_model)
if self.reverse:
- position = torch.arange(
- x.size(1) - 1, -1, -1.0, dtype=torch.float32
- ).unsqueeze(1)
+ position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
@@ -82,38 +79,40 @@
class EENDOLATransformerEncoder(nn.Module):
- def __init__(self,
- idim: int,
- n_layers: int,
- n_units: int,
- e_units: int = 2048,
- h: int = 4,
- dropout_rate: float = 0.1,
- use_pos_emb: bool = False):
+ def __init__(
+ self,
+ idim: int,
+ n_layers: int,
+ n_units: int,
+ e_units: int = 2048,
+ h: int = 4,
+ dropout_rate: float = 0.1,
+ use_pos_emb: bool = False,
+ ):
super(EENDOLATransformerEncoder, self).__init__()
self.linear_in = nn.Linear(idim, n_units)
self.lnorm_in = nn.LayerNorm(n_units)
self.n_layers = n_layers
self.dropout = nn.Dropout(dropout_rate)
for i in range(n_layers):
- setattr(self, '{}{:d}'.format("lnorm1_", i),
- nn.LayerNorm(n_units))
- setattr(self, '{}{:d}'.format("self_att_", i),
- MultiHeadSelfAttention(n_units, h))
- setattr(self, '{}{:d}'.format("lnorm2_", i),
- nn.LayerNorm(n_units))
- setattr(self, '{}{:d}'.format("ff_", i),
- PositionwiseFeedForward(n_units, e_units, dropout_rate))
+ setattr(self, "{}{:d}".format("lnorm1_", i), nn.LayerNorm(n_units))
+ setattr(self, "{}{:d}".format("self_att_", i), MultiHeadSelfAttention(n_units, h))
+ setattr(self, "{}{:d}".format("lnorm2_", i), nn.LayerNorm(n_units))
+ setattr(
+ self,
+ "{}{:d}".format("ff_", i),
+ PositionwiseFeedForward(n_units, e_units, dropout_rate),
+ )
self.lnorm_out = nn.LayerNorm(n_units)
def __call__(self, x, x_mask=None):
BT_size = x.shape[0] * x.shape[1]
e = self.linear_in(x.reshape(BT_size, -1))
for i in range(self.n_layers):
- e = getattr(self, '{}{:d}'.format("lnorm1_", i))(e)
- s = getattr(self, '{}{:d}'.format("self_att_", i))(e, x.shape[0], x_mask)
+ e = getattr(self, "{}{:d}".format("lnorm1_", i))(e)
+ s = getattr(self, "{}{:d}".format("self_att_", i))(e, x.shape[0], x_mask)
e = e + self.dropout(s)
- e = getattr(self, '{}{:d}'.format("lnorm2_", i))(e)
- s = getattr(self, '{}{:d}'.format("ff_", i))(e)
+ e = getattr(self, "{}{:d}".format("lnorm2_", i))(e)
+ s = getattr(self, "{}{:d}".format("ff_", i))(e)
e = e + self.dropout(s)
- return self.lnorm_out(e)
\ No newline at end of file
+ return self.lnorm_out(e)
--
Gitblit v1.9.1