志浩
2023-02-27 97f8201138f6d89bf42f819ef4fbf69889c7f792
Fix bug in SOND model initialization
1个文件已修改
12 ■■■■ 已修改文件
funasr/models/e2e_diar_sond.py 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_diar_sond.py
@@ -85,12 +85,12 @@
            normalize_length=length_normalized_loss,
        )
        self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
        pse_embedding = self.generate_pse_embedding()
        self.register_buffer("pse_embedding", pse_embedding)
        power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
        self.register_buffer("power_weight", power_weight)
        int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
        self.register_buffer("int_token_arr", int_token_arr)
        self.pse_embedding = self.generate_pse_embedding()
        # self.register_buffer("pse_embedding", pse_embedding)
        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float()
        # self.register_buffer("power_weight", power_weight)
        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int()
        # self.register_buffer("int_token_arr", int_token_arr)
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
        self.forward_steps = 0