From 65d1005fd2cd5566ac819aa6f41e43dff9c8a691 Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期一, 27 二月 2023 14:46:01 +0800
Subject: [PATCH] fixbug for sd and sv

---
 funasr/models/e2e_diar_sond.py |   10 +++++++---
 1 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index f55bbf6..ad54723 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -90,6 +90,7 @@
         self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :])
         self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
         self.inter_score_loss_weight = inter_score_loss_weight
+        self.forward_steps = 0
 
     def generate_pse_embedding(self):
         embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float)
@@ -123,7 +124,7 @@
         """
         assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
         batch_size = speech.shape[0]
-
+        self.forward_steps = self.forward_steps + 1
         # 1. Network forward
         pred, inter_outputs = self.prediction_forward(
             speech, speech_lengths,
@@ -198,6 +199,7 @@
             cf=cf,
             acc=acc,
             der=der,
+            forward_steps=self.forward_steps,
         )
 
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
@@ -262,8 +264,10 @@
         self,
         speech: torch.Tensor,
         speech_lengths: torch.Tensor,
-        spk_labels: torch.Tensor = None,
-        spk_labels_lengths: torch.Tensor = None,
+        profile: torch.Tensor = None,
+        profile_lengths: torch.Tensor = None,
+        binary_labels: torch.Tensor = None,
+        binary_labels_lengths: torch.Tensor = None,
     ) -> Dict[str, torch.Tensor]:
         feats, feats_lengths = self._extract_feats(speech, speech_lengths)
         return {"feats": feats, "feats_lengths": feats_lengths}

--
Gitblit v1.9.1