From 04a7ce3205ca478fbc3b1415c2dc31a0769d051c Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期四, 23 二月 2023 17:53:04 +0800
Subject: [PATCH] sond pipeline

---
 funasr/models/e2e_diar_sond.py |   34 +++++++++++++++++++---------------
 1 files changed, 19 insertions(+), 15 deletions(-)

diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index 7b6e955..f55bbf6 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -86,6 +86,8 @@
         )
         self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
         self.pse_embedding = self.generate_pse_embedding()
+        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :])
+        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :])
         self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
         self.inter_score_loss_weight = inter_score_loss_weight
 
@@ -102,8 +104,8 @@
         speech_lengths: torch.Tensor = None,
         profile: torch.Tensor = None,
         profile_lengths: torch.Tensor = None,
-        spk_labels: torch.Tensor = None,
-        spk_labels_lengths: torch.Tensor = None,
+        binary_labels: torch.Tensor = None,
+        binary_labels_lengths: torch.Tensor = None,
     ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
         """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
 
@@ -116,10 +118,10 @@
                                      espnet2/iterators/chunk_iter_factory.py
             profile: (Batch, N_spk, dim)
             profile_lengths: (Batch,)
-            spk_labels: (Batch, frames, input_size)
-            spk_labels_lengths: (Batch,)
+            binary_labels: (Batch, frames, max_spk_num)
+            binary_labels_lengths: (Batch,)
         """
-        assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
+        assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
         batch_size = speech.shape[0]
 
         # 1. Network forward
@@ -132,23 +134,25 @@
 
         # 2. Aggregate time-domain labels to match forward outputs
         if self.label_aggregator is not None:
-            spk_labels, spk_labels_lengths = self.label_aggregator(
-                spk_labels.unsqueeze(2), spk_labels_lengths
+            binary_labels, binary_labels_lengths = self.label_aggregator(
+                binary_labels, binary_labels_lengths
             )
-            spk_labels = spk_labels.squeeze(2)
+        # 3. Calculate power-set encoding (PSE) labels
+        raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True)
+        pse_labels = torch.argmax(raw_pse_labels == self.int_token_arr, dim=2)
 
         # If encoder uses conv* as input_layer (i.e., subsampling),
         # the sequence length of 'pred' might be slightly less than the
         # length of 'spk_labels'. Here we force them to be equal.
         length_diff_tolerance = 2
-        length_diff = spk_labels.shape[1] - pred.shape[1]
+        length_diff = pse_labels.shape[1] - pred.shape[1]
         if 0 < length_diff <= length_diff_tolerance:
-            spk_labels = spk_labels[:, 0: pred.shape[1], :]
+            pse_labels = pse_labels[:, 0: pred.shape[1]]
 
-        loss_diar = self.classification_loss(pred, spk_labels, spk_labels_lengths)
+        loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
         loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
-        loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, spk_labels, spk_labels_lengths)
-        label_mask = make_pad_mask(spk_labels_lengths, maxlen=spk_labels.shape[1])
+        loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
+        label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1])
         loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
                 + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
 
@@ -164,8 +168,8 @@
             speaker_error,
         ) = self.calc_diarization_error(
             pred=F.embedding(pred.argmax(dim=2) * label_mask, self.pse_embedding),
-            label=F.embedding(spk_labels * label_mask, self.pse_embedding),
-            length=spk_labels_lengths
+            label=F.embedding(pse_labels * label_mask, self.pse_embedding),
+            length=binary_labels_lengths
         )
 
         if speech_scored > 0 and num_frames > 0:

--
Gitblit v1.9.1