Yuekai Zhang
2023-03-06 80e6c258cf89b5f11f4e52a4cc5a9cf2e95aa7be
funasr/models/e2e_diar_sond.py
@@ -85,12 +85,12 @@
            normalize_length=length_normalized_loss,
        )
        self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
        pse_embedding = self.generate_pse_embedding()
        self.register_buffer("pse_embedding", pse_embedding)
        power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
        self.register_buffer("power_weight", power_weight)
        int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
        self.register_buffer("int_token_arr", int_token_arr)
        self.pse_embedding = self.generate_pse_embedding()
        # self.register_buffer("pse_embedding", pse_embedding)
        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
        # self.register_buffer("power_weight", power_weight)
        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
        # self.register_buffer("int_token_arr", int_token_arr)
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
        self.forward_steps = 0