From f7f6295327f84fba94d7d96b43a5932cc082df8b Mon Sep 17 00:00:00 2001
From: hnluo <haoneng.lhn@alibaba-inc.com>
Date: 星期五, 10 三月 2023 15:27:38 +0800
Subject: [PATCH] Merge pull request #202 from alibaba-damo-academy/yufan-aslp-patch-1
---
funasr/models/e2e_diar_sond.py | 12 ++++++------
1 files changed, 6 insertions(+), 6 deletions(-)
diff --git a/funasr/models/e2e_diar_sond.py b/funasr/models/e2e_diar_sond.py
index e68d16b..419c813 100644
--- a/funasr/models/e2e_diar_sond.py
+++ b/funasr/models/e2e_diar_sond.py
@@ -85,12 +85,12 @@
normalize_length=length_normalized_loss,
)
self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
- pse_embedding = self.generate_pse_embedding()
- self.register_buffer("pse_embedding", pse_embedding)
- power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
- self.register_buffer("power_weight", power_weight)
- int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
- self.register_buffer("int_token_arr", int_token_arr)
+ self.pse_embedding = self.generate_pse_embedding()
+ # self.register_buffer("pse_embedding", pse_embedding)
+ self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
+ # self.register_buffer("power_weight", power_weight)
+ self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
+ # self.register_buffer("int_token_arr", int_token_arr)
self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
self.inter_score_loss_weight = inter_score_loss_weight
self.forward_steps = 0
--
Gitblit v1.9.1