| | |
| | | self.int_token_arr = torch.from_numpy(np.array(self.token_list).astype(int)[np.newaxis, np.newaxis, :]) |
| | | self.speaker_discrimination_loss_weight = speaker_discrimination_loss_weight |
| | | self.inter_score_loss_weight = inter_score_loss_weight |
| | | self.forward_steps = 0 |
| | | |
| | | def generate_pse_embedding(self): |
| | | embedding = np.zeros((len(self.token_list), self.max_spk_num), dtype=np.float) |
| | |
| | | """ |
| | | assert speech.shape[0] == binary_labels.shape[0], (speech.shape, binary_labels.shape) |
| | | batch_size = speech.shape[0] |
| | | |
| | | self.forward_steps = self.forward_steps + 1 |
| | | # 1. Network forward |
| | | pred, inter_outputs = self.prediction_forward( |
| | | speech, speech_lengths, |
| | |
| | | cf=cf, |
| | | acc=acc, |
| | | der=der, |
| | | forward_steps=self.forward_steps, |
| | | ) |
| | | |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | spk_labels: torch.Tensor = None, |
| | | spk_labels_lengths: torch.Tensor = None, |
| | | profile: torch.Tensor = None, |
| | | profile_lengths: torch.Tensor = None, |
| | | binary_labels: torch.Tensor = None, |
| | | binary_labels_lengths: torch.Tensor = None, |
| | | ) -> Dict[str, torch.Tensor]: |
| | | feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | | return {"feats": feats, "feats_lengths": feats_lengths} |