From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交
---
funasr/models/eend/encoder_decoder_attractor.py | 30 ++++++++++++++++++++++--------
 1 file changed, 22 insertions(+), 8 deletions(-)
diff --git a/funasr/models/eend/encoder_decoder_attractor.py b/funasr/models/eend/encoder_decoder_attractor.py
index 45ac982..6500791 100644
--- a/funasr/models/eend/encoder_decoder_attractor.py
+++ b/funasr/models/eend/encoder_decoder_attractor.py
@@ -25,26 +25,40 @@
max_zlen = torch.max(zlens).to(torch.int).item()
zeros = [self.enc0_dropout(z) for z in zeros]
zeros = nn.utils.rnn.pad_sequence(zeros, batch_first=True, padding_value=-1)
- zeros = nn.utils.rnn.pack_padded_sequence(zeros, zlens, batch_first=True, enforce_sorted=False)
+ zeros = nn.utils.rnn.pack_padded_sequence(
+ zeros, zlens, batch_first=True, enforce_sorted=False
+ )
attractors, (_, _) = self.decoder(zeros, (hx, cx))
- attractors = nn.utils.rnn.pad_packed_sequence(attractors, batch_first=True, padding_value=-1,
- total_length=max_zlen)[0]
- attractors = [att[:zlens[i].to(torch.int).item()] for i, att in enumerate(attractors)]
+ attractors = nn.utils.rnn.pad_packed_sequence(
+ attractors, batch_first=True, padding_value=-1, total_length=max_zlen
+ )[0]
+ attractors = [att[: zlens[i].to(torch.int).item()] for i, att in enumerate(attractors)]
return attractors
def forward(self, xs, n_speakers):
- zeros = [torch.zeros(n_spk + 1, self.n_units).to(torch.float32).to(xs[0].device) for n_spk in n_speakers]
+ zeros = [
+ torch.zeros(n_spk + 1, self.n_units).to(torch.float32).to(xs[0].device)
+ for n_spk in n_speakers
+ ]
attractors = self.forward_core(xs, zeros)
- labels = torch.cat([torch.from_numpy(np.array([[1] * n_spk + [0]], np.float32)) for n_spk in n_speakers], dim=1)
+ labels = torch.cat(
+ [torch.from_numpy(np.array([[1] * n_spk + [0]], np.float32)) for n_spk in n_speakers],
+ dim=1,
+ )
labels = labels.to(xs[0].device)
- logit = torch.cat([self.counter(att).view(-1, n_spk + 1) for att, n_spk in zip(attractors, n_speakers)], dim=1)
+ logit = torch.cat(
+ [self.counter(att).view(-1, n_spk + 1) for att, n_spk in zip(attractors, n_speakers)],
+ dim=1,
+ )
loss = F.binary_cross_entropy(torch.sigmoid(logit), labels)
attractors = [att[slice(0, att.shape[0] - 1)] for att in attractors]
return loss, attractors
def estimate(self, xs, max_n_speakers=15):
- zeros = [torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) for _ in xs]
+ zeros = [
+ torch.zeros(max_n_speakers, self.n_units).to(torch.float32).to(xs[0].device) for _ in xs
+ ]
attractors = self.forward_core(xs, zeros)
probs = [torch.sigmoid(torch.flatten(self.counter(att))) for att in attractors]
return attractors, probs
--
Gitblit v1.9.1