志浩
2023-02-27 65d1005fd2cd5566ac819aa6f41e43dff9c8a691
funasr/models/e2e_diar_sond.py
@@ -90,6 +90,7 @@
        self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :])
        self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight
        self.inter_score_loss_weight = inter_score_loss_weight
        self.forward_steps = 0
    def generate_pse_embedding(self):
        embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=float)
@@ -123,7 +124,7 @@
        """
        assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape)
        batch_size = speech.shape[0]
        self.forward_steps = self.forward_steps + 1
        # 1. Network forward
        pred, inter_outputs = self.prediction_forward(
            speech, speech_lengths,
@@ -198,6 +199,7 @@
            cf=cf,
            acc=acc,
            der=der,
            forward_steps=self.forward_steps,
        )
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
@@ -262,8 +264,10 @@
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        spk_labels: torch.Tensor = None,
        spk_labels_lengths: torch.Tensor = None,
        profile: torch.Tensor = None,
        profile_lengths: torch.Tensor = None,
        binary_labels: torch.Tensor = None,
        binary_labels_lengths: torch.Tensor = None,
    ) -> Dict[str, torch.Tensor]:
        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        return {"feats": feats, "feats_lengths": feats_lengths}