| | |
| | | def forward_post_net(self, logits, ilens): |
| | | maxlen = torch.max(ilens).to(torch.int).item() |
| | | logits = nn.utils.rnn.pad_sequence(logits, batch_first=True, padding_value=-1) |
| | | logits = nn.utils.rnn.pack_padded_sequence(logits, ilens, batch_first=True, enforce_sorted=False) |
| | | logits = nn.utils.rnn.pack_padded_sequence(logits, ilens.cpu().to(torch.int64), batch_first=True, enforce_sorted=False) |
| | | outputs, (_, _) = self.postnet(logits) |
| | | outputs = nn.utils.rnn.pad_packed_sequence(outputs, batch_first=True, padding_value=-1, total_length=maxlen)[0] |
| | | outputs = [output[:ilens[i].to(torch.int).item()] for i, output in enumerate(outputs)] |
| | |
| | | pred[i] = pred[i - 1] |
| | | else: |
| | | pred[i] = 0 |
| | | pred = [self.reporter.inv_mapping_func(i, self.mapping_dict) for i in pred] |
| | | pred = [self.inv_mapping_func(i) for i in pred] |
| | | decisions = [bin(num)[2:].zfill(self.max_n_speaker)[::-1] for num in pred] |
| | | decisions = torch.from_numpy( |
| | | np.stack([np.array([int(i) for i in dec]) for dec in decisions], axis=0)).to(logit.device).to( |
| | |
| | | decisions = decisions[:, :n_speaker] |
| | | return decisions |
| | | |
def inv_mapping_func(self, label):
    """Translate a class label into the number stored for it in ``label2dec``.

    The lookup table is ``self.mapping_dict['label2dec']``.

    Args:
        label: Class label to look up. Any non-``int`` value (e.g. a float,
            a numeric string, or a single-element tensor) is first converted
            with ``int()``.

    Returns:
        The value mapped to ``label`` in ``self.mapping_dict['label2dec']``,
        or ``-1`` when ``label`` is not a key of that dict.
        NOTE(review): -1 is a sentinel; a caller that feeds the result to
        ``bin()`` and parses the digits must filter it out first.
    """
    if not isinstance(label, int):
        # Normalize to a plain int so the label can match the table's keys.
        label = int(label)
    # One lookup with a default instead of a membership test plus indexing.
    return self.mapping_dict['label2dec'].get(label, -1)
| | | |
def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Feature-collection hook, left as a no-op for this model.

    Accepts any keyword tensors and ignores them.

    Returns:
        Always ``None``. NOTE(review): this does not satisfy the declared
        ``Dict[str, torch.Tensor]`` return type; confirm no caller expects
        a mapping back from this method.
    """
    return None