| | |
| | | normalize_speech_speaker: bool = False, |
| | | ignore_id: int = -1, |
| | | speaker_discrimination_loss_weight: float = 1.0, |
| | | inter_score_loss_weight: float = 0.0, |
| | | inputs_type: str = "raw", |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | |
| | | ) |
| | | self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss) |
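| | | # Lookup tensors for the power-set encoding (PSE) of speaker activity: pse_embedding |
| | | # holds the binary speaker-activity pattern of each label token, power_weight holds |
| | | # powers of 2 for folding per-speaker binary labels into a single PSE index |
| | | # (e.g. with max_spk_num = 4, a frame labelled [1, 0, 1, 0] folds to 1*1 + 1*4 = 5), |
| | | # and int_token_arr is the token list cast to integer values. |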
| | | self.pse_embedding = self.generate_pse_embedding() |
| | | # self.register_buffer("pse_embedding", pse_embedding) |
| | | self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]).float() |
| | | # self.register_buffer("power_weight", power_weight) |
| | | self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]).int() |
| | | # self.register_buffer("int_token_arr", int_token_arr) |
| | | self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight |
| | | self.inter_score_loss_weight = inter_score_loss_weight |
| | | self.forward_steps = 0 |
| | | self.inputs_type = inputs_type |
| | | |
| | | def generate_pse_embedding(self): |
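| | | # Build the (len(token_list), max_spk_num) table holding the binary speaker-activity |
| | | # pattern encoded by each label token in the power-set vocabulary. |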
| | | embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=float) |
| | |
| | | binary_labels: (Batch, frames, max_spk_num) |
| | | binary_labels_lengths: (Batch,) |
| | | """ |
| | | assert speech.shape[0] <= binary_labels.shape[0], (speech.shape, binary_labels.shape) |
| | | batch_size = speech.shape[0] |
| | | self.forward_steps += 1 |
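| | | # The lookup tensors above are plain attributes (register_buffer is commented out), |
| | | # so move them to the input device lazily on the first forward call. |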
| | | if self.pse_embedding.device != speech.device: |
| | | self.pse_embedding = self.pse_embedding.to(speech.device) |
| | | self.power_weight = self.power_weight.to(speech.device) |
| | | self.int_token_arr = self.int_token_arr.to(speech.device) |
| | | |
| | | # 1. Network forward |
| | | pred, inter_outputs = self.prediction_forward( |
| | | speech, speech_lengths, |
| | |
| | | # the sequence length of 'pred' may differ slightly from that of
| | | # 'pse_labels'. Here we force them to be equal.
| | | length_diff_tolerance = 2 |
| | | length_diff = abs(pse_labels.shape[1] - pred.shape[1]) |
| | | if length_diff <= length_diff_tolerance: |
| | | min_len = min(pred.shape[1], pse_labels.shape[1]) |
| | | pse_labels = pse_labels[:, :min_len] |
| | | pred = pred[:, :min_len] |
| | | cd_score = cd_score[:, :min_len] |
| | | ci_score = ci_score[:, :min_len] |
| | | |
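| | | # Frame-level PSE classification loss for diarization, plus an auxiliary |
| | | # speaker-discrimination loss computed on the speaker profiles. |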
| | | loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths) |
| | | loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths) |
| | |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
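| | | # For raw inputs, run the speech encoder first; pre-encoded inputs |
| | | # (inputs_type != "raw") skip the encoder. |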
| | | if self.encoder is not None and self.inputs_type == "raw": |
| | | speech, speech_lengths = self.encode(speech, speech_lengths) |
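| | | # Mask out padded frames: (Batch, frames, 1) float mask built from speech_lengths. |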
| | | speech_mask = ~make_pad_mask(speech_lengths, maxlen=speech.shape[1]) |
| | | speech_mask = speech_mask.to(speech.device).unsqueeze(-1).float() |