| | |
| | | import numpy as np |
| | | import torch |
| | | import torch.multiprocessing |
| | | import torch.nn.functional as F |
| | | from itertools import combinations |
| | | from itertools import permutations |
| | | |
| | | |
def generate_mapping_dict(max_speaker_num=6, max_olp_speaker_num=3):
    """Build lookup tables between speaker-activity bitmasks and power-set labels.

    Every set of at most ``max_olp_speaker_num`` simultaneously active speakers,
    out of ``max_speaker_num`` speaker slots, is encoded as an integer bitmask.
    Bit ``s`` is set when speaker ``s`` is active, and the empty set (no active
    speaker) is 0. The bitmasks are sorted ascending and numbered 0, 1, 2, ...

    Args:
        max_speaker_num: number of speaker slots, i.e. bits in the bitmask.
        max_olp_speaker_num: largest number of simultaneously active speakers
            that receives its own label.

    Returns:
        dict with keys:
            'dec2label': {bitmask (int): label (int)}
            'label2dec': {label (int): bitmask (int)}
            'oov': label reserved for bitmasks not in the table; equals the
                number of regular labels.
    """
    # Bit s corresponds to column s of an activity matrix (column 0 is the
    # least-significant bit), the same weighting raw_dec_trans applies.
    # Integer shifts keep every bitmask exact, with no float round trip.
    all_kinds = [0]  # no active speaker
    for n_active in range(1, max_olp_speaker_num + 1):
        for com in combinations(range(max_speaker_num), n_active):
            all_kinds.append(sum(1 << s for s in com))
    all_kinds_order = sorted(all_kinds)

    mapping_dict = {'dec2label': {}, 'label2dec': {}}
    for label, dec in enumerate(all_kinds_order):
        mapping_dict['dec2label'][dec] = label
        mapping_dict['label2dec'][label] = dec
    mapping_dict['oov'] = len(all_kinds_order)
    return mapping_dict
| | | |
| | | |
def raw_dec_trans(x, max_speaker_num):
    """Read each row of ``x`` as a little-endian binary number.

    Column ``i`` contributes ``x[:, i] * 2**i``. Only the first
    ``max_speaker_num`` columns are used, so ``x`` must have at least that many.

    Args:
        x: 2-D array indexed ``x[frame, speaker]`` (presumably 0/1 activity
            values — TODO confirm against callers).
        max_speaker_num: number of leading columns to encode.

    Returns:
        float64 numpy array of shape ``(x.shape[0],)`` with one value per row.
    """
    decimal = np.zeros((x.shape[0]))
    weight = 1
    for column in range(max_speaker_num):
        # Accumulate least-significant column first, doubling the weight.
        decimal += x[:, column] * weight
        weight *= 2
    return decimal
| | | |
| | | |
def mapping_func(num, mapping_dict):
    """Map one bitmask value to its power-set label, or to the OOV label.

    Args:
        num: decimal bitmask value. Any hashable that compares equal to an
            int key works, e.g. a float such as ``3.0``.
        mapping_dict: dict with 'dec2label' and 'oov' entries (the layout
            produced by ``generate_mapping_dict``).

    Returns:
        ``mapping_dict['dec2label'][num]`` when present, otherwise
        ``mapping_dict['oov']``.
    """
    # One dict lookup instead of a membership test followed by indexing.
    # Labels built by generate_mapping_dict are ints, so None means "missing".
    # 'oov' is read only on a miss, as before.
    label = mapping_dict['dec2label'].get(num)
    return mapping_dict['oov'] if label is None else label
| | | |
| | | |
def dec_trans(x, max_speaker_num, mapping_dict):
    """Convert an activity matrix into one power-set label per row.

    Each row is first read as a bitmask by ``raw_dec_trans``, with column ``i``
    weighted by ``2**i``. The bitmask is then mapped through ``mapping_func``,
    so bitmasks missing from ``mapping_dict['dec2label']`` get
    ``mapping_dict['oov']``.

    Args:
        x: 2-D array indexed ``x[frame, speaker]`` with at least
            ``max_speaker_num`` columns.
        max_speaker_num: number of leading columns to encode.
        mapping_dict: dict with 'dec2label' and 'oov' entries.

    Returns:
        int64 numpy array of shape ``(x.shape[0],)`` with one label per row.
    """
    # Reuse the shared bitmask encoder instead of duplicating its loop here.
    dec = raw_dec_trans(x, max_speaker_num)
    # Explicit dtype: without it an empty input gives a float64 array, and the
    # default integer width can be platform dependent on older NumPy.
    return np.array([mapping_func(v, mapping_dict) for v in dec], dtype=np.int64)
| | | |
| | | |
def create_powerlabel(label, mapping_dict, max_speaker_num=6, max_olp_speaker_num=3):
    """Encode a 2-D activity matrix as a tensor of power-set labels.

    The speaker axis is zero-padded to ``max_speaker_num`` columns, and each
    row is then converted by ``dec_trans``.

    Args:
        label: 2-D numpy array indexed ``label[frame, speaker]`` with at most
            ``max_speaker_num`` columns.
        mapping_dict: dict with 'dec2label' and 'oov' entries.
        max_speaker_num: width to pad the speaker axis to.
        max_olp_speaker_num: accepted for interface compatibility; not used
            by this function.

    Returns:
        torch tensor (via ``torch.from_numpy``) with one label per frame.
    """
    n_frames, n_cols = label.shape
    # Missing speaker columns count as inactive (zero bits).
    padded = np.zeros((n_frames, max_speaker_num))
    padded[:, :n_cols] = label
    return torch.from_numpy(dec_trans(padded, max_speaker_num, mapping_dict))
| | | |
| | | |
def generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=3):
    """Build every speaker permutation of ``label`` and its power-set labels.

    Each permutation selects and reorders columns ``0..n_speaker-1`` of
    ``label``. Presumably ``n_speaker`` equals the number of label columns —
    TODO confirm against callers.

    Args:
        label: torch tensor indexed ``label[frame, speaker]``.
        n_speaker: number of columns to permute.
        mapping_dict: dict with 'dec2label' and 'oov' entries.
        max_speaker_num: width the speaker axis is padded to before encoding.
        max_olp_speaker_num: forwarded to ``create_powerlabel``.

    Returns:
        ``(perm_labels, perm_pse_labels)``: two lists with one entry per
        permutation, in ``itertools.permutations`` order.
        ``perm_labels[k]`` is ``label[:, perm_k]`` on ``label.device``.
        ``perm_pse_labels[k]`` is its ``create_powerlabel`` encoding, moved
        to ``label.device``.
    """
    perm_list = list(permutations(range(n_speaker)))
    # Build integer indices directly, with no float32 round trip.
    perms = torch.tensor(perm_list, dtype=torch.int64, device=label.device)
    perm_labels = [label[:, perm] for perm in perms]

    # Copy to the host once, instead of once per permutation. detach() lets
    # this work even if label requires grad, where .numpy() would raise.
    label_host = label.detach().cpu().numpy()
    perm_pse_labels = [
        create_powerlabel(label_host[:, list(perm)], mapping_dict, max_speaker_num,
                          max_olp_speaker_num=max_olp_speaker_num).to(label.device, non_blocking=True)
        for perm in perm_list
    ]
    return perm_labels, perm_pse_labels
| | | |
| | | |
def generate_min_pse(label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3):
    """Pick the speaker permutation whose power-set labels best fit ``pse_logit``.

    All permutations are produced by ``generate_perm_pse``. The permutation
    whose labels give the smallest ``F.cross_entropy`` against ``pse_logit``
    is returned. Ties go to the earliest permutation (``torch.argmin``).

    Args:
        label: torch tensor indexed ``label[frame, speaker]``.
        n_speaker: number of columns to permute.
        mapping_dict: dict with 'dec2label' and 'oov' entries.
        max_speaker_num: width the speaker axis is padded to before encoding.
        pse_logit: logits passed as ``input`` to ``F.cross_entropy``;
            presumably shape (frames, num_classes) — TODO confirm.
        max_olp_speaker_num: forwarded to ``generate_perm_pse``.

    Returns:
        ``(selected_perm_label, selected_pse_label)``: the permuted activity
        tensor and its power-set label tensor for the best permutation.
    """
    perm_labels, perm_pse_labels = generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num,
                                                     max_olp_speaker_num=max_olp_speaker_num)
    # These losses only drive a selection and are never backpropagated, so
    # skip building an autograd graph for each permutation. The mean loss is
    # enough: scaling by a positive constant cannot change the argmin.
    with torch.no_grad():
        losses = torch.stack([F.cross_entropy(input=pse_logit, target=pse_label.to(torch.long))
                              for pse_label in perm_pse_labels])
    min_index = int(torch.argmin(losses))
    return perm_labels[min_index], perm_pse_labels[min_index]