From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/eend/encoder_decoder_attractor.py |   30 ++++++++++++++++++++++--------
 1 files changed, 22 insertions(+), 8 deletions(-)

diff --git a/funasr/models/eend/encoder_decoder_attractor.py b/funasr/models/eend/encoder_decoder_attractor.py
index 45ac982..6500791 100644
--- a/funasr/models/eend/encoder_decoder_attractor.py
+++ b/funasr/models/eend/encoder_decoder_attractor.py
@@ -25,26 +25,40 @@
         max_zlen = torch.max(zlens).to(torch.int).item()
         zeros = [self.enc0_dropout(z) for z in zeros]
         zeros = nn.utils.rnn.pad_sequence(zeros, batch_first=True, padding_value=-1)
-        zeros = nn.utils.rnn.pack_padded_sequence(zeros, zlens, batch_first=True, enforce_sorted=False)
+        zeros = nn.utils.rnn.pack_padded_sequence(
+            zeros, zlens, batch_first=True, enforce_sorted=False
+        )
         attractors, (_, _) = self.decoder(zeros, (hx, cx))
-        attractors = nn.utils.rnn.pad_packed_sequence(attractors, batch_first=True, padding_value=-1,
-                                                      total_length=max_zlen)[0]
-        attractors = [att[:zlens[i].to(torch.int).item()] for i, att in enumerate(attractors)]
+        attractors = nn.utils.rnn.pad_packed_sequence(
+            attractors, batch_first=True, padding_value=-1, total_length=max_zlen
+        )[0]
+        attractors = [att[: zlens[i].to(torch.int).item()] for i, att in enumerate(attractors)]
         return attractors
 
     def forward(self, xs, n_speakers):
-        zeros = [torch.zeros(n_spk + 1, self.n_units).to(torch.float32).to(xs[0].device) for n_spk in n_speakers]
+        zeros = [
+            torch.zeros(n_spk + 1, self.n_units).to(torch.float32).to(xs[0].device)
+            for n_spk in n_speakers
+        ]
         attractors = self.forward_core(xs, zeros)
-        labels = torch.cat([torch.from_numpy(np.array([[1] * n_spk + [0]], np.float32)) for n_spk in n_speakers], dim=1)
+        labels = torch.cat(
+            [torch.from_numpy(np.array([[1] * n_spk + [0]], np.float32)) for n_spk in n_speakers],
+            dim=1,
+        )
         labels = labels.to(xs[0].device)
-        logit = torch.cat([self.counter(att).view(-1, n_spk + 1) for att, n_spk in zip(attractors, n_speakers)], dim=1)
+        logit = torch.cat(
+            [self.counter(att).view(-1, n_spk + 1) for att, n_spk in zip(attractors, n_speakers)],
+            dim=1,
+        )
         loss = F.binary_cross_entropy(torch.sigmoid(logit), labels)
 
         attractors = [att[slice(0, att.shape[0] - 1)] for att in attractors]
         return loss, attractors
 
     def estimate(self, xs, max_n_speakers=15):
-        zeros = [torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) for _ in xs]
+        zeros = [
+            torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) for _ in xs
+        ]
         attractors = self.forward_core(xs, zeros)
         probs = [torch.sigmoid(torch.flatten(self.counter(att))) for att in attractors]
         return attractors, probs

--
Gitblit v1.9.1