From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/models/eend/utils/power.py | 39 ++++++++++++++++++++++++---------------
1 files changed, 24 insertions(+), 15 deletions(-)
diff --git a/funasr/models/eend/utils/power.py b/funasr/models/eend/utils/power.py
index 7144e24..b6b32c5 100644
--- a/funasr/models/eend/utils/power.py
+++ b/funasr/models/eend/utils/power.py
@@ -20,14 +20,14 @@
all_kinds_order = sorted(all_kinds)
mapping_dict = {}
- mapping_dict['dec2label'] = {}
- mapping_dict['label2dec'] = {}
+ mapping_dict["dec2label"] = {}
+ mapping_dict["label2dec"] = {}
for i in range(len(all_kinds_order)):
dec = all_kinds_order[i]
- mapping_dict['dec2label'][dec] = i
- mapping_dict['label2dec'][i] = dec
+ mapping_dict["dec2label"][dec] = i
+ mapping_dict["label2dec"][i] = dec
oov_id = len(all_kinds_order)
- mapping_dict['oov'] = oov_id
+ mapping_dict["oov"] = oov_id
return mapping_dict
@@ -45,10 +45,10 @@
def mapping_func(num, mapping_dict):
- if num in mapping_dict['dec2label'].keys():
- label = mapping_dict['dec2label'][num]
+ if num in mapping_dict["dec2label"].keys():
+ label = mapping_dict["dec2label"][num]
else:
- label = mapping_dict['oov']
+ label = mapping_dict["oov"]
return label
@@ -79,16 +79,25 @@
perms = np.array(list(permutations(range(n_speaker)))).astype(np.float32)
perms = torch.from_numpy(perms).to(label.device).to(torch.int64)
perm_labels = [label[:, perm] for perm in perms]
- perm_pse_labels = [create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num).
- to(perm_label.device, non_blocking=True) for perm_label in perm_labels]
+ perm_pse_labels = [
+ create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num).to(
+ perm_label.device, non_blocking=True
+ )
+ for perm_label in perm_labels
+ ]
return perm_labels, perm_pse_labels
-def generate_min_pse(label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3):
- perm_labels, perm_pse_labels = generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num,
- max_olp_speaker_num=max_olp_speaker_num)
- losses = [F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit)
- for perm_pse_label in perm_pse_labels]
+def generate_min_pse(
+ label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3
+):
+ perm_labels, perm_pse_labels = generate_perm_pse(
+ label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=max_olp_speaker_num
+ )
+ losses = [
+ F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit)
+ for perm_pse_label in perm_pse_labels
+ ]
loss = torch.stack(losses)
min_index = torch.argmin(loss)
selected_perm_label, selected_pse_label = perm_labels[min_index], perm_pse_labels[min_index]
--
Gitblit v1.9.1