hnluo
2023-03-10 1eacb8ae811a7e68c296523e6a8b4e72eb024aa2
funasr/models/e2e_diar_sond.py
@@ -85,12 +85,12 @@
            normalize_length=length_normalized_loss,
        )
        self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
        pse_embedding = self.generate_pse_embedding()
        self.register_buffer("pse_embedding", pse_embedding)
        power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
        self.register_buffer("power_weight", power_weight)
        int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
        self.register_buffer("int_token_arr", int_token_arr)
        self.pse_embedding = self.generate_pse_embedding()
        # self.register_buffer("pse_embedding", pse_embedding)
        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
        # self.register_buffer("power_weight", power_weight)
        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
        # self.register_buffer("int_token_arr", int_token_arr)
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
        self.forward_steps = 0