Shi Xian
2024-12-05 0efc87352ce7d3903dbdedbfa5d01ca5e1cb19e7
funasr/models/eend/utils/power.py
@@ -20,14 +20,14 @@
    all_kinds_order = sorted(all_kinds)
    mapping_dict = {}
    mapping_dict['dec2label'] = {}
    mapping_dict['label2dec'] = {}
    mapping_dict["dec2label"] = {}
    mapping_dict["label2dec"] = {}
    for i in range(len(all_kinds_order)):
        dec = all_kinds_order[i]
        mapping_dict['dec2label'][dec] = i
        mapping_dict['label2dec'][i] = dec
        mapping_dict["dec2label"][dec] = i
        mapping_dict["label2dec"][i] = dec
    oov_id = len(all_kinds_order)
    mapping_dict['oov'] = oov_id
    mapping_dict["oov"] = oov_id
    return mapping_dict
@@ -45,10 +45,10 @@
def mapping_func(num, mapping_dict):
    """Translate a decimal code into its label index, falling back to the OOV id.

    Args:
        num: Key to look up in ``mapping_dict["dec2label"]``.
        mapping_dict: Mapping holding a ``"dec2label"`` dictionary and an
            ``"oov"`` entry used for keys that are not present.

    Returns:
        ``mapping_dict["dec2label"][num]`` when ``num`` is a known key,
        otherwise ``mapping_dict["oov"]``.

    Raises:
        KeyError: If ``"dec2label"`` is missing, or if ``num`` is unknown and
            ``"oov"`` is missing.
    """
    dec2label = mapping_dict["dec2label"]
    if num in dec2label:
        return dec2label[num]
    # "oov" is read only on a miss, so a mapping without that entry still
    # works for known keys.
    return mapping_dict["oov"]
@@ -79,16 +79,25 @@
    perms = np.array(list(permutations(range(n_speaker)))).astype(np.float32)
    perms = torch.from_numpy(perms).to(label.device).to(torch.int64)
    perm_labels = [label[:, perm] for perm in perms]
    perm_pse_labels = [create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num).
                           to(perm_label.device, non_blocking=True) for perm_label in perm_labels]
    perm_pse_labels = [
        create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num).to(
            perm_label.device, non_blocking=True
        )
        for perm_label in perm_labels
    ]
    return perm_labels, perm_pse_labels
def generate_min_pse(label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3):
    perm_labels, perm_pse_labels = generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num,
                                                     max_olp_speaker_num=max_olp_speaker_num)
    losses = [F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit)
              for perm_pse_label in perm_pse_labels]
def generate_min_pse(
    label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3
):
    perm_labels, perm_pse_labels = generate_perm_pse(
        label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=max_olp_speaker_num
    )
    losses = [
        F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit)
        for perm_pse_label in perm_pse_labels
    ]
    loss = torch.stack(losses)
    min_index = torch.argmin(loss)
    selected_perm_label, selected_pse_label = perm_labels[min_index], perm_pse_labels[min_index]