| | |
| | | |
| | | |
def GetArgs():
    """Parse and validate the command-line arguments of the script.

    The invoking command line is echoed to stderr exactly once, then
    ``sys.argv`` is parsed and the result is passed through ``CheckArgs``.

    Returns:
        The ``argparse.Namespace`` returned by ``CheckArgs``, carrying
        ``p_target``, ``c_miss``, ``c_fa``, ``scores_filename`` and
        ``trials_filename``.

    Raises:
        Exception: propagated from ``CheckArgs`` when a cost or the target
            prior is out of range.
        SystemExit: raised by argparse on malformed command lines.
    """
    parser = argparse.ArgumentParser(
        description="Compute the minimum "
        "detection cost function along with the threshold at which it occurs. "
        "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
        "<trials-file> "
        "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
        "exp/scores/trials data/test/trials",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--p-target",
        type=float,
        dest="p_target",
        default=0.01,
        help="The prior probability of the target speaker in a trial.",
    )
    parser.add_argument(
        "--c-miss",
        type=float,
        dest="c_miss",
        default=1,
        help="Cost of a missed detection. This is usually not changed.",
    )
    parser.add_argument(
        "--c-fa",
        type=float,
        dest="c_fa",
        default=1,
        help="Cost of a spurious detection. This is usually not changed.",
    )
    parser.add_argument(
        "scores_filename",
        help="Input scores file, with columns of the form <utt1> <utt2> <score>",
    )
    parser.add_argument(
        "trials_filename",
        help="Input trials file, with columns of the form <utt1> <utt2> <target/nontarget>",
    )

    # Echo the command line before parsing so it is logged even when
    # argparse rejects the arguments and exits.
    sys.stderr.write(" ".join(sys.argv) + "\n")

    args = parser.parse_args()
    return CheckArgs(args)
| | |
| | | |
def CheckArgs(args):
    """Validate the parsed cost-model arguments.

    Args:
        args: Object exposing numeric ``c_fa``, ``c_miss`` and ``p_target``
            attributes (the namespace produced by the argument parser).

    Returns:
        The same ``args`` object, unchanged, when every check passes.

    Raises:
        Exception: if ``c_fa`` or ``c_miss`` is not strictly positive, or if
            ``p_target`` is not strictly between 0 and 1.
    """
    # The checks are phrased as "not (valid)" rather than "invalid" so that
    # NaN, for which every ordering comparison is False, is rejected too.
    if not args.c_fa > 0:
        raise Exception("--c-fa must be greater than 0")
    if not args.c_miss > 0:
        raise Exception("--c-miss must be greater than 0")
    if not 0 < args.p_target < 1:
        raise Exception("--p-target must be greater than 0 and less than 1")
    return args
| | | |
| | | |
| | |
| | | # Sort the scores from smallest to largest, and also get the corresponding |
| | | # indexes of the sorted scores. We will treat the sorted scores as the |
| | | # thresholds at which the error-rates are evaluated. |
| | | sorted_indexes, thresholds = zip(*sorted( |
| | | [(index, threshold) for index, threshold in enumerate(scores)], |
| | | key=itemgetter(1))) |
| | | sorted_indexes, thresholds = zip( |
| | | *sorted([(index, threshold) for index, threshold in enumerate(scores)], key=itemgetter(1)) |
| | | ) |
| | | labels = [labels[i] for i in sorted_indexes] |
| | | fns = [] |
| | | tns = [] |
| | |
| | | fns.append(labels[i]) |
| | | tns.append(1 - labels[i]) |
| | | else: |
| | | fns.append(fns[i-1] + labels[i]) |
| | | tns.append(tns[i-1] + 1 - labels[i]) |
| | | fns.append(fns[i - 1] + labels[i]) |
| | | tns.append(tns[i - 1] + 1 - labels[i]) |
| | | positives = sum(labels) |
| | | negatives = len(labels) - positives |
| | | |
| | | # Now divide the false negatives by the total number of |
| | | # positives to obtain the false negative rates across |
| | | # all thresholds. |
| | | fnrs = [fn / float(positives) for fn in fns] |
| | | |
| | | # Divide the true negatives by the total number of |
| | | # negatives to get the true negative rate. Subtract these |
| | | # quantities from 1 to get the false positive rates. |
| | | fprs = [1 - tn / float(negatives) for tn in tns] |
| | | return fnrs, fprs, thresholds |
| | |
| | | |
| | | |
| | | def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01): |
| | | scores_file = open(scores_filename, 'r').readlines() |
| | | trials_file = open(trials_filename, 'r').readlines() |
| | | scores_file = open(scores_filename, "r").readlines() |
| | | trials_file = open(trials_filename, "r").readlines() |
| | | c_miss = c_miss |
| | | c_fa = c_fa |
| | | p_target = p_target |
| | |
| | | else: |
| | | labels.append(0) |
| | | else: |
| | | raise Exception("Missing entry for " + utt1 + " and " + utt2 |
| | | + " " + scores_filename) |
| | | raise Exception("Missing entry for " + utt1 + " and " + utt2 + " " + scores_filename) |
| | | |
| | | fnrs, fprs, thresholds = ComputeErrorRates(scores, labels) |
| | | mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, |
| | | c_miss, c_fa) |
| | | mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa) |
| | | return mindcf, threshold |
| | | |
| | | |
def main():
    """Run the command-line tool: parse arguments, compute minDCF, report it.

    Obtains the arguments from ``GetArgs``, forwards the two file names and
    the cost-model parameters to ``compute_min_dcf``, and writes a single
    summary line to stdout with the minimum DCF, the threshold at which it
    occurs, and the parameters used.
    """
    args = GetArgs()
    # Keywords are used for the cost parameters so the call cannot silently
    # swap c_miss / c_fa / p_target if the positional order is misread.
    mindcf, threshold = compute_min_dcf(
        args.scores_filename,
        args.trials_filename,
        c_miss=args.c_miss,
        c_fa=args.c_fa,
        p_target=args.p_target,
    )
    sys.stdout.write(
        "minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
        "c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa)
    )
| | | |
| | | |
| | | if __name__ == "__main__": |