From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
funasr/metrics/compute_min_dcf.py | 101 +++++++++++++++++++++++++++++---------------------
1 files changed, 59 insertions(+), 42 deletions(-)
diff --git a/funasr/metrics/compute_min_dcf.py b/funasr/metrics/compute_min_dcf.py
index 610113a..472f7b3 100644
--- a/funasr/metrics/compute_min_dcf.py
+++ b/funasr/metrics/compute_min_dcf.py
@@ -16,27 +16,45 @@
def GetArgs():
- parser = argparse.ArgumentParser(description="Compute the minimum "
- "detection cost function along with the threshold at which it occurs. "
- "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
- "<trials-file> "
- "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
- "exp/scores/trials data/test/trials",
- formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument('--p-target', type=float, dest="p_target",
- default=0.01,
- help='The prior probability of the target speaker in a trial.')
- parser.add_argument('--c-miss', type=float, dest="c_miss", default=1,
- help='Cost of a missed detection. This is usually not changed.')
- parser.add_argument('--c-fa', type=float, dest="c_fa", default=1,
- help='Cost of a spurious detection. This is usually not changed.')
- parser.add_argument("scores_filename",
- help="Input scores file, with columns of the form "
- "<utt1> <utt2> <score>")
- parser.add_argument("trials_filename",
- help="Input trials file, with columns of the form "
- "<utt1> <utt2> <target/nontarget>")
- sys.stderr.write(' '.join(sys.argv) + "\n")
+ parser = argparse.ArgumentParser(
+ description="Compute the minimum "
+ "detection cost function along with the threshold at which it occurs. "
+ "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
+ "<trials-file> "
+ "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
+ "exp/scores/trials data/test/trials",
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+ )
+ parser.add_argument(
+ "--p-target",
+ type=float,
+ dest="p_target",
+ default=0.01,
+ help="The prior probability of the target speaker in a trial.",
+ )
+ parser.add_argument(
+ "--c-miss",
+ type=float,
+ dest="c_miss",
+ default=1,
+ help="Cost of a missed detection. This is usually not changed.",
+ )
+ parser.add_argument(
+ "--c-fa",
+ type=float,
+ dest="c_fa",
+ default=1,
+ help="Cost of a spurious detection. This is usually not changed.",
+ )
+ parser.add_argument(
+ "scores_filename",
+ help="Input scores file, with columns of the form " "<utt1> <utt2> <score>",
+ )
+ parser.add_argument(
+ "trials_filename",
+ help="Input trials file, with columns of the form " "<utt1> <utt2> <target/nontarget>",
+ )
+ sys.stderr.write(" ".join(sys.argv) + "\n")
args = parser.parse_args()
args = CheckArgs(args)
return args
@@ -44,11 +62,11 @@
def CheckArgs(args):
if args.c_fa <= 0:
- raise Exception("--c-fa must be greater than 0")
+ raise Exception("--c-fa must be greater than 0")
if args.c_miss <= 0:
- raise Exception("--c-miss must be greater than 0")
+ raise Exception("--c-miss must be greater than 0")
if args.p_target <= 0 or args.p_target >= 1:
- raise Exception("--p-target must be greater than 0 and less than 1")
+ raise Exception("--p-target must be greater than 0 and less than 1")
return args
@@ -59,9 +77,9 @@
# Sort the scores from smallest to largest, and also get the corresponding
# indexes of the sorted scores. We will treat the sorted scores as the
# thresholds at which the the error-rates are evaluated.
- sorted_indexes, thresholds = zip(*sorted(
- [(index, threshold) for index, threshold in enumerate(scores)],
- key=itemgetter(1)))
+ sorted_indexes, thresholds = zip(
+ *sorted([(index, threshold) for index, threshold in enumerate(scores)], key=itemgetter(1))
+ )
labels = [labels[i] for i in sorted_indexes]
fns = []
tns = []
@@ -75,18 +93,18 @@
fns.append(labels[i])
tns.append(1 - labels[i])
else:
- fns.append(fns[i-1] + labels[i])
- tns.append(tns[i-1] + 1 - labels[i])
+ fns.append(fns[i - 1] + labels[i])
+ tns.append(tns[i - 1] + 1 - labels[i])
positives = sum(labels)
negatives = len(labels) - positives
- # Now divide the false negatives by the total number of
+ # Now divide the false negatives by the total number of
# positives to obtain the false negative rates across
# all thresholds
fnrs = [fn / float(positives) for fn in fns]
- # Divide the true negatives by the total number of
- # negatives to get the true negative rate. Subtract these
+ # Divide the true negatives by the total number of
+ # negatives to get the true negative rate. Subtract these
# quantities from 1 to get the false positive rates.
fprs = [1 - tn / float(negatives) for tn in tns]
return fnrs, fprs, thresholds
@@ -111,8 +129,8 @@
def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01):
- scores_file = open(scores_filename, 'r').readlines()
- trials_file = open(trials_filename, 'r').readlines()
+ scores_file = open(scores_filename, "r").readlines()
+ trials_file = open(trials_filename, "r").readlines()
c_miss = c_miss
c_fa = c_fa
p_target = p_target
@@ -136,23 +154,22 @@
else:
labels.append(0)
else:
- raise Exception("Missing entry for " + utt1 + " and " + utt2
- + " " + scores_filename)
+ raise Exception("Missing entry for " + utt1 + " and " + utt2 + " " + scores_filename)
fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
- mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target,
- c_miss, c_fa)
+ mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa)
return mindcf, threshold
def main():
args = GetArgs()
mindcf, threshold = compute_min_dcf(
- args.scores_filename, args.trials_filename,
- args.c_miss, args.c_fa, args.p_target
+ args.scores_filename, args.trials_filename, args.c_miss, args.c_fa, args.p_target
)
- sys.stdout.write("minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
- "c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa))
+ sys.stdout.write(
+ "minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
+ "c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa)
+ )
if __name__ == "__main__":
--
Gitblit v1.9.1