From d783b24ba7d8a03dabfa2139fcbf40c216e0ea3d Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Thu, 16 Mar 2023 19:34:52 +0800
Subject: [PATCH] Merge pull request #199 from alibaba-damo-academy/dev_xw

---
 funasr/models/e2e_diar_sond.py |   13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index e68d16b..258d780 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -85,12 +85,12 @@
             normalize_length=length_normalized_loss,
         )
         self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
-        pse_embedding = self.generate_pse_embedding()
-        self.register_buffer("pse_embedding", pse_embedding)
-        power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
-        self.register_buffer("power_weight", power_weight)
-        int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
-        self.register_buffer("int_token_arr", int_token_arr)
+        self.pse_embedding = self.generate_pse_embedding()
+        # self.register_buffer("pse_embedding", pse_embedding)
+        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
+        # self.register_buffer("power_weight", power_weight)
+        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
+        # self.register_buffer("int_token_arr", int_token_arr)
         self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
         self.inter_score_loss_weight = inter_score_loss_weight
         self.forward_steps = 0
@@ -342,6 +342,7 @@
 
         if isinstance(self.ci_scorer, AbsEncoder):
             ci_simi = self.ci_scorer(ge_in, ge_len)[0]
+            ci_simi = torch.reshape(ci_simi, [bb, self.max_spk_num, tt]).permute([0, 2, 1])
         else:
             ci_simi = self.ci_scorer(speech_encoder_outputs, speaker_encoder_outputs)
 

--
Gitblit v1.9.1