| | |
| | | all_kinds_order = sorted(all_kinds) |
| | | |
| | | mapping_dict = {} |
| | | mapping_dict['dec2label'] = {} |
| | | mapping_dict['label2dec'] = {} |
| | | mapping_dict["dec2label"] = {} |
| | | mapping_dict["label2dec"] = {} |
| | | for i in range(len(all_kinds_order)): |
| | | dec = all_kinds_order[i] |
| | | mapping_dict['dec2label'][dec] = i |
| | | mapping_dict['label2dec'][i] = dec |
| | | mapping_dict["dec2label"][dec] = i |
| | | mapping_dict["label2dec"][i] = dec |
| | | oov_id = len(all_kinds_order) |
| | | mapping_dict['oov'] = oov_id |
| | | mapping_dict["oov"] = oov_id |
| | | return mapping_dict |
| | | |
| | | |
| | |
| | | |
| | | |
def mapping_func(num, mapping_dict):
    """Translate ``num`` into its label index, with an out-of-vocabulary fallback.

    Args:
        num: Key to look up in ``mapping_dict["dec2label"]``.
        mapping_dict: Dictionary holding a ``"dec2label"`` sub-dictionary and
            an ``"oov"`` entry that is returned for keys missing from it.

    Returns:
        ``mapping_dict["dec2label"][num]`` when ``num`` is a known key,
        otherwise ``mapping_dict["oov"]``.

    Raises:
        KeyError: If ``"dec2label"`` is absent, or if ``num`` is unknown and
            ``"oov"`` is absent.
    """
    dec2label = mapping_dict["dec2label"]
    # Membership is tested on the dict itself; the "oov" entry is only read
    # when the key is actually missing.
    if num in dec2label:
        return dec2label[num]
    return mapping_dict["oov"]
| | | |
| | | |
| | |
| | | perms = np.array(list(permutations(range(n_speaker)))).astype(np.float32) |
| | | perms = torch.from_numpy(perms).to(label.device).to(torch.int64) |
| | | perm_labels = [label[:, perm] for perm in perms] |
| | | perm_pse_labels = [create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num). |
| | | to(perm_label.device, non_blocking=True) for perm_label in perm_labels] |
| | | perm_pse_labels = [ |
| | | create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num).to( |
| | | perm_label.device, non_blocking=True |
| | | ) |
| | | for perm_label in perm_labels |
| | | ] |
| | | return perm_labels, perm_pse_labels |
| | | |
| | | |
| | | def generate_min_pse(label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3): |
| | | perm_labels, perm_pse_labels = generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num, |
| | | max_olp_speaker_num=max_olp_speaker_num) |
| | | losses = [F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit) |
| | | for perm_pse_label in perm_pse_labels] |
| | | def generate_min_pse( |
| | | label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3 |
| | | ): |
| | | perm_labels, perm_pse_labels = generate_perm_pse( |
| | | label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=max_olp_speaker_num |
| | | ) |
| | | losses = [ |
| | | F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit) |
| | | for perm_pse_label in perm_pse_labels |
| | | ] |
| | | loss = torch.stack(losses) |
| | | min_index = torch.argmin(loss) |
| | | selected_perm_label, selected_pse_label = perm_labels[min_index], perm_pse_labels[min_index] |