志浩
2023-02-23 04a7ce3205ca478fbc3b1415c2dc31a0769d051c
funasr/models/e2e_diar_sond.py
@@ -86,6 +86,8 @@
        )
        self.criterion_bce = SequenceBinaryCrossEntropy(normalize_length=length_normalized_loss)
        self.pse_embedding = self.generate_pse_embedding()
        self.power_weight = torch.from_numpy(2 ** np.arange(max_spk_num)[np.newaxis, np.newaxis, :])
        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :])
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
@@ -102,8 +104,8 @@
        speech_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        spk_labels: torch.Tensor = None,
        spk_labels_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Speaker Encoder + CI Scorer + CD Scorer + Decoder + Calc loss
@@ -116,10 +118,10 @@
                                     espnet2/iterators/chunk_iter_factory.py
            profile: (Batch, N_spk, dim)
            profile_lengths: (Batch,)
            spk_labels: (Batch, frames, input_size)
            spk_labels_lengths: (Batch,)
            binary_labels: (Batch, frames, max_spk_num)
            binary_labels_lengths: (Batch,)
        """
        assert speech.shape[0] == spk_labels.shape[0], (speech.shape, spk_labels.shape)
        assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
        batch_size = speech.shape[0]
        # 1. Network forward
@@ -132,23 +134,25 @@
        # 2. Aggregate time-domain labels to match forward outputs
        if self.label_aggregator is not None:
            spk_labels, spk_labels_lengths = self.label_aggregator(
                spk_labels.unsqueeze(2), spk_labels_lengths
            binary_labels, binary_labels_lengths = self.label_aggregator(
                binary_labels, binary_labels_lengths
            )
            spk_labels = spk_labels.squeeze(2)
        # 2. Calculate power-set encoding (PSE) labels
        raw_pse_labels = torch.sum(binary_labels * self.power_weight, dim=2, keepdim=True)
        pse_labels = torch.argmax(raw_pse_labels == self.int_token_arr, dim=2)
        # If encoder uses conv* as input_layer (i.e., subsampling),
        # the sequence length of 'pred' might be slightly less than the
        # length of 'spk_labels'. Here we force them to be equal.
        length_diff_tolerance = 2
        length_diff = spk_labels.shape[1] - pred.shape[1]
        length_diff = pse_labels.shape[1] - pred.shape[1]
        if 0 < length_diff <= length_diff_tolerance:
            spk_labels = spk_labels[:, 0: pred.shape[1], :]
            pse_labels = pse_labels[:, 0: pred.shape[1]]
        loss_diar = self.classification_loss(pred, spk_labels, spk_labels_lengths)
        loss_diar = self.classification_loss(pred, pse_labels, binary_labels_lengths)
        loss_spk_dis = self.speaker_discrimination_loss(profile, profile_lengths)
        loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, spk_labels, spk_labels_lengths)
        label_mask = make_pad_mask(spk_labels_lengths, maxlen=spk_labels.shape[1])
        loss_inter_ci, loss_inter_cd = self.internal_score_loss(cd_score, ci_score, pse_labels, binary_labels_lengths)
        label_mask = make_pad_mask(binary_labels_lengths, maxlen=pse_labels.shape[1])
        loss = (loss_diar + self.speaker_discrimination_loss_weight * loss_spk_dis
                + self.inter_score_loss_weight * (loss_inter_ci + loss_inter_cd))
@@ -164,8 +168,8 @@
            speaker_error,
        ) = self.calc_diarization_error(
            pred=F.embedding(pred.argmax(dim=2) * label_mask, self.pse_embedding),
            label=F.embedding(spk_labels * label_mask, self.pse_embedding),
            length=spk_labels_lengths
            label=F.embedding(pse_labels * label_mask, self.pse_embedding),
            length=binary_labels_lengths
        )
        if speech_scored > 0 and num_frames > 0: