From 04a7ce3205ca478fbc3b1415c2dc31a0769d051c Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: Thu, 23 Feb 2023 17:53:04 +0800
Subject: [PATCH] SOND pipeline
---
funasr/models/e2e_diar_sond.py | 34 +++++++++++++++++++---------------
 1 file changed, 19 insertions(+), 15 deletions(-)
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index 7b6e955..f55bbf6 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -86,6 +86,8 @@
)
self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
self.pse_embedding = self.generate_pse_embedding()
+ self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :])
+ self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :])
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
self.inter_score_loss_weight = inter_score_loss_weight
@@ -102,8 +104,8 @@
speech_lengths: torch.Tensor = None,
profile: torch.Tensor = None,
profile_lengths: torch.Tensor = None,
- spk_labels: torch.Tensor = None,
- spk_labels_lengths: torch.Tensor = None,
+ binary_labels: torch.Tensor = None,
+ binary_labels_lengths: torch.Tensor = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
"""Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
@@ -116,10 +118,10 @@
espnet2/iterators/chunk_iter_factory.py
profile: (Batch, N_spk, dim)
profile_lengths: (Batch,)
- spk_labels: (Batch, frames, input_size)
- spk_labels_lengths: (Batch,)
+ binary_labels: (Batch, frames, max_spk_num)
+ binary_labels_lengths: (Batch,)
"""
- assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
+ assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
batch_size = speech.shape[0]
# 1. Network forward
@@ -132,23 +134,25 @@
# 2. Aggregate time-domain labels to match forward outputs
if self.label_aggregator is not None:
- spk_labels, spk_labels_lengths = self.label_aggregator(
- spk_labels.unsqueeze(2), spk_labels_lengths
+ binary_labels, binary_labels_lengths = self.label_aggregator(
+ binary_labels, binary_labels_lengths
)
- spk_labels = spk_labels.squeeze(2)
+ # 3. Calculate power-set encoding (PSE) labels
+ raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True)
+ pse_labels = torch.argmax(raw_pse_labels == self.int_token_arr, dim=2)
# If encoder uses conv* as input_layer (i.e., subsampling),
# the sequence length of 'pred' might be slightly less than the
# length of 'spk_labels'. Here we force them to be equal.
length_diff_tolerance = 2
- length_diff = spk_labels.shape[1] - pred.shape[1]
+ length_diff = pse_labels.shape[1] - pred.shape[1]
if 0 < length_diff <= length_diff_tolerance:
- spk_labels = spk_labels[:, 0: pred.shape[1], :]
+ pse_labels = pse_labels[:, 0: pred.shape[1]]
- loss_diar = self.classification_loss(pred, spk_labels, spk_labels_lengths)
+ loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
- loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, spk_labels, spk_labels_lengths)
- label_mask = make_pad_mask(spk_labels_lengths, maxlen=spk_labels.shape[1])
+ loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
+ label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1])
loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
+ self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
@@ -164,8 +168,8 @@
speaker_error,
) = self.calc_diarization_error(
pred=F.embedding(pred.argmax(dim=2) * label_mask, self.pse_embedding),
- label=F.embedding(spk_labels * label_mask, self.pse_embedding),
- length=spk_labels_lengths
+ label=F.embedding(pse_labels * label_mask, self.pse_embedding),
+ length=binary_labels_lengths
)
if speech_scored > 0 and num_frames > 0:
--
Gitblit v1.9.1