From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/eend/utils/power.py |   39 ++++++++++++++++++++++++---------------
 1 files changed, 24 insertions(+), 15 deletions(-)

diff --git a/funasr/models/eend/utils/power.py b/funasr/models/eend/utils/power.py
index 7144e24..b6b32c5 100644
--- a/funasr/models/eend/utils/power.py
+++ b/funasr/models/eend/utils/power.py
@@ -20,14 +20,14 @@
     all_kinds_order = sorted(all_kinds)
 
     mapping_dict = {}
-    mapping_dict['dec2label'] = {}
-    mapping_dict['label2dec'] = {}
+    mapping_dict["dec2label"] = {}
+    mapping_dict["label2dec"] = {}
     for i in range(len(all_kinds_order)):
         dec = all_kinds_order[i]
-        mapping_dict['dec2label'][dec] = i
-        mapping_dict['label2dec'][i] = dec
+        mapping_dict["dec2label"][dec] = i
+        mapping_dict["label2dec"][i] = dec
     oov_id = len(all_kinds_order)
-    mapping_dict['oov'] = oov_id
+    mapping_dict["oov"] = oov_id
     return mapping_dict
 
 
@@ -45,10 +45,10 @@
 
 
 def mapping_func(num, mapping_dict):
-    if num in mapping_dict['dec2label'].keys():
-        label = mapping_dict['dec2label'][num]
+    if num in mapping_dict["dec2label"].keys():
+        label = mapping_dict["dec2label"][num]
     else:
-        label = mapping_dict['oov']
+        label = mapping_dict["oov"]
     return label
 
 
@@ -79,16 +79,25 @@
     perms = np.array(list(permutations(range(n_speaker)))).astype(np.float32)
     perms = torch.from_numpy(perms).to(label.device).to(torch.int64)
     perm_labels = [label[:, perm] for perm in perms]
-    perm_pse_labels = [create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num).
-                           to(perm_label.device, non_blocking=True) for perm_label in perm_labels]
+    perm_pse_labels = [
+        create_powerlabel(perm_label.cpu().numpy(), mapping_dict, max_speaker_num).to(
+            perm_label.device, non_blocking=True
+        )
+        for perm_label in perm_labels
+    ]
     return perm_labels, perm_pse_labels
 
 
-def generate_min_pse(label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3):
-    perm_labels, perm_pse_labels = generate_perm_pse(label, n_speaker, mapping_dict, max_speaker_num,
-                                                     max_olp_speaker_num=max_olp_speaker_num)
-    losses = [F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit)
-              for perm_pse_label in perm_pse_labels]
+def generate_min_pse(
+    label, n_speaker, mapping_dict, max_speaker_num, pse_logit, max_olp_speaker_num=3
+):
+    perm_labels, perm_pse_labels = generate_perm_pse(
+        label, n_speaker, mapping_dict, max_speaker_num, max_olp_speaker_num=max_olp_speaker_num
+    )
+    losses = [
+        F.cross_entropy(input=pse_logit, target=perm_pse_label.to(torch.long)) * len(pse_logit)
+        for perm_pse_label in perm_pse_labels
+    ]
     loss = torch.stack(losses)
     min_index = torch.argmin(loss)
     selected_perm_label, selected_pse_label = perm_labels[min_index], perm_pse_labels[min_index]

--
Gitblit v1.9.1