| | |
| | | normalize_length=length_normalized_loss, |
| | | ) |
| | | self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss) |
| | | pse_embedding = self.generate_pse_embedding() |
| | | self.register_buffer("pse_embedding", pse_embedding) |
| | | power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float() |
| | | self.register_buffer("power_weight", power_weight) |
| | | int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int() |
| | | self.register_buffer("int_token_arr", int_token_arr) |
| | | self.pse_embedding = self.generate_pse_embedding() |
| | | # self.register_buffer("pse_embedding", pse_embedding) |
| | | self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float() |
| | | # self.register_buffer("power_weight", power_weight) |
| | | self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int() |
| | | # self.register_buffer("int_token_arr", int_token_arr) |
| | | self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight |
| | | self.inter_score_loss_weight = inter_score_loss_weight |
| | | self.forward_steps = 0 |