| | |
| | | ) |
| | | self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss) |
| | | self.pse_embedding = self.generate_pse_embedding() |
| | | self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :]) |
| | | self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]) |
| | | self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight |
| | | self.inter_score_loss_weight = inter_score_loss_weight |
| | | |
| | |
| | | speech_lengths: torch.Tensor = None, |
| | | profile: torch.Tensor = None, |
| | | profile_lengths: torch.Tensor = None, |
| | | spk_labels: torch.Tensor = None, |
| | | spk_labels_lengths: torch.Tensor = None, |
| | | binary_labels: torch.Tensor = None, |
| | | binary_labels_lengths: torch.Tensor = None, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss |
| | | |
| | |
| | | espnet2/iterators/chunk_iter_factory.py |
| | | profile: (Batch, N_spk, dim) |
| | | profile_lengths: (Batch,) |
| | | spk_labels: (Batch, frames, input_size) |
| | | spk_labels_lengths: (Batch,) |
| | | binary_labels: (Batch, frames, max_spk_num) |
| | | binary_labels_lengths: (Batch,) |
| | | """ |
| | | assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape) |
| | | assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape) |
| | | batch_size = speech.shape[0] |
| | | |
| | | # 1. Network forward |
| | |
| | | |
| | | # 2. Aggregate time-domain labels to match forward outputs |
| | | if self.label_aggregator is not None: |
| | | spk_labels, spk_labels_lengths = self.label_aggregator( |
| | | spk_labels.unsqueeze(2), spk_labels_lengths |
| | | binary_labels, binary_labels_lengths = self.label_aggregator( |
| | | binary_labels, binary_labels_lengths |
| | | ) |
| | | spk_labels = spk_labels.squeeze(2) |
| | | # 3. Calculate power-set encoding (PSE) labels |
| | | raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True) |
| | | pse_labels = torch.argmax(raw_pse_labels == self.int_token_arr, dim=2) |
| | | |
| | | # If encoder uses conv* as input_layer (i.e., subsampling), |
| | | # the sequence length of 'pred' might be slightly less than the |
| | | # length of 'spk_labels'. Here we force them to be equal. |
| | | length_diff_tolerance = 2 |
| | | length_diff = spk_labels.shape[1] - pred.shape[1] |
| | | length_diff = pse_labels.shape[1] - pred.shape[1] |
| | | if 0 < length_diff <= length_diff_tolerance: |
| | | spk_labels = spk_labels[:, 0: pred.shape[1], :] |
| | | pse_labels = pse_labels[:, 0: pred.shape[1]] |
| | | |
| | | loss_diar = self.classification_loss(pred, spk_labels, spk_labels_lengths) |
| | | loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths) |
| | | loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths) |
| | | loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, spk_labels, spk_labels_lengths) |
| | | label_mask = make_pad_mask(spk_labels_lengths, maxlen=spk_labels.shape[1]) |
| | | loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths) |
| | | label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1]) |
| | | loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis |
| | | + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd)) |
| | | |
| | |
| | | speaker_error, |
| | | ) = self.calc_diarization_error( |
| | | pred=F.embedding(pred.argmax(dim=2) * label_mask, self.pse_embedding), |
| | | label=F.embedding(spk_labels * label_mask, self.pse_embedding), |
| | | length=spk_labels_lengths |
| | | label=F.embedding(pse_labels * label_mask, self.pse_embedding), |
| | | length=binary_labels_lengths |
| | | ) |
| | | |
| | | if speech_scored > 0 and num_frames > 0: |