From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/metrics/compute_min_dcf.py |  101 +++++++++++++++++++++++++++++---------------------
 1 files changed, 59 insertions(+), 42 deletions(-)

diff --git a/funasr/metrics/compute_min_dcf.py b/funasr/metrics/compute_min_dcf.py
index 610113a..472f7b3 100644
--- a/funasr/metrics/compute_min_dcf.py
+++ b/funasr/metrics/compute_min_dcf.py
@@ -16,27 +16,45 @@
 
 
 def GetArgs():
-    parser = argparse.ArgumentParser(description="Compute the minimum "
-                                                 "detection cost function along with the threshold at which it occurs. "
-                                                 "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
-                                                 "<trials-file> "
-                                                 "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
-                                                 "exp/scores/trials data/test/trials",
-                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    parser.add_argument('--p-target', type=float, dest="p_target",
-                        default=0.01,
-                        help='The prior probability of the target speaker in a trial.')
-    parser.add_argument('--c-miss', type=float, dest="c_miss", default=1,
-                        help='Cost of a missed detection.  This is usually not changed.')
-    parser.add_argument('--c-fa', type=float, dest="c_fa", default=1,
-                        help='Cost of a spurious detection.  This is usually not changed.')
-    parser.add_argument("scores_filename",
-                        help="Input scores file, with columns of the form "
-                             "<utt1> <utt2> <score>")
-    parser.add_argument("trials_filename",
-                        help="Input trials file, with columns of the form "
-                             "<utt1> <utt2> <target/nontarget>")
-    sys.stderr.write(' '.join(sys.argv) + "\n")
+    parser = argparse.ArgumentParser(
+        description="Compute the minimum "
+        "detection cost function along with the threshold at which it occurs. "
+        "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
+        "<trials-file> "
+        "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
+        "exp/scores/trials data/test/trials",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+    parser.add_argument(
+        "--p-target",
+        type=float,
+        dest="p_target",
+        default=0.01,
+        help="The prior probability of the target speaker in a trial.",
+    )
+    parser.add_argument(
+        "--c-miss",
+        type=float,
+        dest="c_miss",
+        default=1,
+        help="Cost of a missed detection.  This is usually not changed.",
+    )
+    parser.add_argument(
+        "--c-fa",
+        type=float,
+        dest="c_fa",
+        default=1,
+        help="Cost of a spurious detection.  This is usually not changed.",
+    )
+    parser.add_argument(
+        "scores_filename",
+        help="Input scores file, with columns of the form " "<utt1> <utt2> <score>",
+    )
+    parser.add_argument(
+        "trials_filename",
+        help="Input trials file, with columns of the form " "<utt1> <utt2> <target/nontarget>",
+    )
+    sys.stderr.write(" ".join(sys.argv) + "\n")
     args = parser.parse_args()
     args = CheckArgs(args)
     return args
@@ -44,11 +62,11 @@
 
 def CheckArgs(args):
     if args.c_fa <= 0:
-      raise Exception("--c-fa must be greater than 0")
+        raise Exception("--c-fa must be greater than 0")
     if args.c_miss <= 0:
-      raise Exception("--c-miss must be greater than 0")
+        raise Exception("--c-miss must be greater than 0")
     if args.p_target <= 0 or args.p_target >= 1:
-      raise Exception("--p-target must be greater than 0 and less than 1")
+        raise Exception("--p-target must be greater than 0 and less than 1")
     return args
 
 
@@ -59,9 +77,9 @@
     # Sort the scores from smallest to largest, and also get the corresponding
     # indexes of the sorted scores.  We will treat the sorted scores as the
     # thresholds at which the the error-rates are evaluated.
-    sorted_indexes, thresholds = zip(*sorted(
-        [(index, threshold) for index, threshold in enumerate(scores)],
-        key=itemgetter(1)))
+    sorted_indexes, thresholds = zip(
+        *sorted([(index, threshold) for index, threshold in enumerate(scores)], key=itemgetter(1))
+    )
     labels = [labels[i] for i in sorted_indexes]
     fns = []
     tns = []
@@ -75,18 +93,18 @@
             fns.append(labels[i])
             tns.append(1 - labels[i])
         else:
-            fns.append(fns[i-1] + labels[i])
-            tns.append(tns[i-1] + 1 - labels[i])
+            fns.append(fns[i - 1] + labels[i])
+            tns.append(tns[i - 1] + 1 - labels[i])
     positives = sum(labels)
     negatives = len(labels) - positives
 
-    # Now divide the false negatives by the total number of 
+    # Now divide the false negatives by the total number of
     # positives to obtain the false negative rates across
     # all thresholds
     fnrs = [fn / float(positives) for fn in fns]
 
-    # Divide the true negatives by the total number of 
-    # negatives to get the true negative rate. Subtract these 
+    # Divide the true negatives by the total number of
+    # negatives to get the true negative rate. Subtract these
     # quantities from 1 to get the false positive rates.
     fprs = [1 - tn / float(negatives) for tn in tns]
     return fnrs, fprs, thresholds
@@ -111,8 +129,8 @@
 
 
 def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01):
-    scores_file = open(scores_filename, 'r').readlines()
-    trials_file = open(trials_filename, 'r').readlines()
+    scores_file = open(scores_filename, "r").readlines()
+    trials_file = open(trials_filename, "r").readlines()
     c_miss = c_miss
     c_fa = c_fa
     p_target = p_target
@@ -136,23 +154,22 @@
             else:
                 labels.append(0)
         else:
-            raise Exception("Missing entry for " + utt1 + " and " + utt2
-                            + " " + scores_filename)
+            raise Exception("Missing entry for " + utt1 + " and " + utt2 + " " + scores_filename)
 
     fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
-    mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target,
-                                      c_miss, c_fa)
+    mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa)
     return mindcf, threshold
 
 
 def main():
     args = GetArgs()
     mindcf, threshold = compute_min_dcf(
-        args.scores_filename, args.trials_filename,
-        args.c_miss, args.c_fa, args.p_target
+        args.scores_filename, args.trials_filename, args.c_miss, args.c_fa, args.p_target
     )
-    sys.stdout.write("minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
-                     "c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa))
+    sys.stdout.write(
+        "minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
+        "c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa)
+    )
 
 
 if __name__ == "__main__":

--
Gitblit v1.9.1