speech_asr
2023-03-08 0892f5ce5240fde47fdcfc6f4faea8bfad6dc0ce
funasr/modules/eend_ola/utils/power.py
@@ -0,0 +1,95 @@
import numpy as np
import torch
import torch.multiprocessing
import torch.nn.functional as F
from itertools import combinations
from itertools import permutations
def generate_mapping_dict(max_speaker_num=6, max_olp_speaker_num=3):
    """Build the lookup tables between speaker-activity codes and class labels.

    A set of simultaneously active speakers is encoded as an integer whose
    bit ``k`` is set when speaker ``k`` is active. Codes are generated for the
    all-silent case (0) and for every combination of 1 to
    ``max_olp_speaker_num`` speakers drawn from ``max_speaker_num`` speakers.
    The codes are sorted and numbered consecutively starting at 0.

    Args:
        max_speaker_num: Number of speaker slots (bits) in a code.
        max_olp_speaker_num: Largest number of speakers allowed to be active
            at once among the enumerated codes.

    Returns:
        A dict with three entries:
            'dec2label': maps each enumerated code to its label index.
            'label2dec': the inverse mapping, label index to code.
            'oov': the label used for codes that were not enumerated; it
                equals the number of enumerated codes.
    """
    all_kinds = [0]
    for selected_num in range(1, max_olp_speaker_num + 1):
        for com in combinations(range(max_speaker_num), selected_num):
            # Exact integer arithmetic: a float accumulator would silently
            # merge distinct codes once more than 53 bits are involved.
            all_kinds.append(sum(1 << speaker for speaker in com))
    all_kinds_order = sorted(all_kinds)
    mapping_dict = {
        'dec2label': {dec: label for label, dec in enumerate(all_kinds_order)},
        'label2dec': dict(enumerate(all_kinds_order)),
        'oov': len(all_kinds_order),
    }
    return mapping_dict
def raw_dec_trans(x, max_speaker_num):
    """Encode each row of ``x`` as a weighted sum of its first columns.

    Column ``k`` is weighted by ``2 ** k``, so a row of 0/1 activity flags
    becomes the integer whose bit ``k`` mirrors column ``k``.

    Args:
        x: 2-D array indexed as ``x[:, k]``; it must have at least
            ``max_speaker_num`` columns.
        max_speaker_num: Number of leading columns to include.

    Returns:
        A 1-D float numpy array of length ``x.shape[0]`` holding the codes.
    """
    columns = [x[:, speaker] for speaker in range(max_speaker_num)]
    codes = np.zeros((x.shape[0]))
    for power, column in enumerate(columns):
        # Weight of column k is 2**k (least-significant bit first).
        codes += column * (2 ** power)
    return codes
def mapping_func(num, mapping_dict):
    """Translate a code into its label, falling back to the OOV label.

    Args:
        num: Code to look up in ``mapping_dict['dec2label']``.
        mapping_dict: Dict providing a 'dec2label' table and an 'oov' label.

    Returns:
        The label stored for ``num``, or ``mapping_dict['oov']`` when ``num``
        is not a key of the table.
    """
    known_codes = mapping_dict['dec2label']
    if num not in known_codes:
        return mapping_dict['oov']
    return known_codes[num]
def dec_trans(x, max_speaker_num, mapping_dict):
    """Encode each row of ``x`` and translate the code into a class label.

    Args:
        x: 2-D array indexed as ``x[:, k]``; it must have at least
            ``max_speaker_num`` columns.
        max_speaker_num: Number of leading columns that form the code.
        mapping_dict: Dict providing a 'dec2label' table and an 'oov' label,
            as consumed by ``mapping_func``.

    Returns:
        A 1-D numpy array with one label per row of ``x``; codes missing from
        the table are given the 'oov' label.
    """
    # Reuse the shared encoder instead of duplicating its accumulation loop.
    codes = raw_dec_trans(x, max_speaker_num)
    return np.array([mapping_func(code, mapping_dict) for code in codes])
def create_powerlabel(label, mapping_dict, max_speaker_num=6, max_olp_speaker_num=3):
    """Convert a per-speaker activity matrix into per-row class labels.

    The input is copied into the leading columns of a zero matrix that is
    ``max_speaker_num`` columns wide, then each row is encoded and mapped
    with ``dec_trans``.

    Args:
        label: 2-D numpy array; its shape is unpacked as (rows, speakers).
        mapping_dict: Lookup tables passed through to ``dec_trans``.
        max_speaker_num: Width the input is zero-padded to.
        max_olp_speaker_num: Accepted for interface compatibility; not used
            in this function.

    Returns:
        A torch tensor built from the numpy array of labels, one per row.
    """
    n_rows, n_cols = label.shape
    padded = np.zeros((n_rows, max_speaker_num))
    # Speakers beyond those present in the input stay inactive (zero).
    padded[:, :n_cols] = label
    return torch.from_numpy(dec_trans(padded, max_speaker_num, mapping_dict))
def generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=3):
    """Produce every column permutation of ``label`` with its class labels.

    Args:
        label: Tensor indexed as ``label[:, perm]`` and exposing ``.device``
            (presumably a torch tensor of shape (rows, n_speaker); confirm
            against the caller).
        n_speaker: Number of columns to permute.
        mapping_dict: Lookup tables passed through to ``create_powerlabel``.
        max_speaker_num: Padding width passed to ``create_powerlabel``.
        max_olp_speaker_num: Accepted for interface compatibility; not used
            in this function.

    Returns:
        A pair of lists in matching order: the permuted label tensors and the
        class-label tensors computed from them, placed on the same device as
        each permuted tensor.
    """
    orderings = torch.tensor(list(permutations(range(n_speaker))),
                             dtype=torch.int64, device=label.device)
    perm_labels = []
    perm_pse_labels = []
    for ordering in orderings:
        reordered = label[:, ordering]
        # Label construction runs in numpy, so round-trip through the CPU.
        pse = create_powerlabel(reordered.cpu().numpy(), mapping_dict, max_speaker_num)
        perm_labels.append(reordered)
        perm_pse_labels.append(pse.to(reordered.device, non_blocking=True))
    return perm_labels, perm_pse_labels
def generate_min_pse(label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3):
    """Pick the column permutation of ``label`` with the lowest loss.

    Every permutation from ``generate_perm_pse`` is scored by the cross
    entropy between ``pse_logit`` and the permutation's class labels,
    multiplied by ``len(pse_logit)``.

    Args:
        label: Reference activity tensor, forwarded to ``generate_perm_pse``.
        n_speaker: Number of columns to permute.
        mapping_dict: Lookup tables forwarded to ``generate_perm_pse``.
        max_speaker_num: Padding width forwarded to ``generate_perm_pse``.
        pse_logit: Input passed to ``F.cross_entropy`` (presumably of shape
            (rows, classes); confirm against the caller).
        max_olp_speaker_num: Forwarded to ``generate_perm_pse``.

    Returns:
        The permuted label tensor and its class-label tensor for the
        permutation with the smallest loss.
    """
    perm_labels, perm_pse_labels = generate_perm_pse(
        label, n_speaker, mapping_dict, max_speaker_num,
        max_olp_speaker_num=max_olp_speaker_num)
    candidate_losses = []
    for candidate in perm_pse_labels:
        mean_loss = F.cross_entropy(input=pse_logit, target=candidate.to(torch.long))
        candidate_losses.append(mean_loss * len(pse_logit))
    best = torch.argmin(torch.stack(candidate_losses))
    return perm_labels[best], perm_pse_labels[best]