liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/metrics/compute_min_dcf.py
@@ -16,27 +16,45 @@
def GetArgs():
    parser = argparse.ArgumentParser(description="Compute the minimum "
                                                 "detection cost function along with the threshold at which it occurs. "
                                                 "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
                                                 "<trials-file> "
                                                 "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
                                                 "exp/scores/trials data/test/trials",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--p-target', type=float, dest="p_target",
                        default=0.01,
                        help='The prior probability of the target speaker in a trial.')
    parser.add_argument('--c-miss', type=float, dest="c_miss", default=1,
                        help='Cost of a missed detection.  This is usually not changed.')
    parser.add_argument('--c-fa', type=float, dest="c_fa", default=1,
                        help='Cost of a spurious detection.  This is usually not changed.')
    parser.add_argument("scores_filename",
                        help="Input scores file, with columns of the form "
                             "<utt1> <utt2> <score>")
    parser.add_argument("trials_filename",
                        help="Input trials file, with columns of the form "
                             "<utt1> <utt2> <target/nontarget>")
    sys.stderr.write(' '.join(sys.argv) + "\n")
    parser = argparse.ArgumentParser(
        description="Compute the minimum "
        "detection cost function along with the threshold at which it occurs. "
        "Usage: sid/compute_min_dcf.py [options...] <scores-file> "
        "<trials-file> "
        "E.g., sid/compute_min_dcf.py --p-target 0.01 --c-miss 1 --c-fa 1 "
        "exp/scores/trials data/test/trials",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--p-target",
        type=float,
        dest="p_target",
        default=0.01,
        help="The prior probability of the target speaker in a trial.",
    )
    parser.add_argument(
        "--c-miss",
        type=float,
        dest="c_miss",
        default=1,
        help="Cost of a missed detection.  This is usually not changed.",
    )
    parser.add_argument(
        "--c-fa",
        type=float,
        dest="c_fa",
        default=1,
        help="Cost of a spurious detection.  This is usually not changed.",
    )
    parser.add_argument(
        "scores_filename",
        help="Input scores file, with columns of the form " "<utt1> <utt2> <score>",
    )
    parser.add_argument(
        "trials_filename",
        help="Input trials file, with columns of the form " "<utt1> <utt2> <target/nontarget>",
    )
    sys.stderr.write(" ".join(sys.argv) + "\n")
    args = parser.parse_args()
    args = CheckArgs(args)
    return args
@@ -44,11 +62,11 @@
def CheckArgs(args):
    if args.c_fa <= 0:
      raise Exception("--c-fa must be greater than 0")
        raise Exception("--c-fa must be greater than 0")
    if args.c_miss <= 0:
      raise Exception("--c-miss must be greater than 0")
        raise Exception("--c-miss must be greater than 0")
    if args.p_target <= 0 or args.p_target >= 1:
      raise Exception("--p-target must be greater than 0 and less than 1")
        raise Exception("--p-target must be greater than 0 and less than 1")
    return args
@@ -59,9 +77,9 @@
    # Sort the scores from smallest to largest, and also get the corresponding
    # indexes of the sorted scores.  We will treat the sorted scores as the
    # thresholds at which the the error-rates are evaluated.
    sorted_indexes, thresholds = zip(*sorted(
        [(index, threshold) for index, threshold in enumerate(scores)],
        key=itemgetter(1)))
    sorted_indexes, thresholds = zip(
        *sorted([(index, threshold) for index, threshold in enumerate(scores)], key=itemgetter(1))
    )
    labels = [labels[i] for i in sorted_indexes]
    fns = []
    tns = []
@@ -75,18 +93,18 @@
            fns.append(labels[i])
            tns.append(1 - labels[i])
        else:
            fns.append(fns[i-1] + labels[i])
            tns.append(tns[i-1] + 1 - labels[i])
            fns.append(fns[i - 1] + labels[i])
            tns.append(tns[i - 1] + 1 - labels[i])
    positives = sum(labels)
    negatives = len(labels) - positives
    # Now divide the false negatives by the total number of
    # Now divide the false negatives by the total number of
    # positives to obtain the false negative rates across
    # all thresholds
    fnrs = [fn / float(positives) for fn in fns]
    # Divide the true negatives by the total number of
    # negatives to get the true negative rate. Subtract these
    # Divide the true negatives by the total number of
    # negatives to get the true negative rate. Subtract these
    # quantities from 1 to get the false positive rates.
    fprs = [1 - tn / float(negatives) for tn in tns]
    return fnrs, fprs, thresholds
@@ -111,8 +129,8 @@
def compute_min_dcf(scores_filename, trials_filename, c_miss=1, c_fa=1, p_target=0.01):
    scores_file = open(scores_filename, 'r').readlines()
    trials_file = open(trials_filename, 'r').readlines()
    scores_file = open(scores_filename, "r").readlines()
    trials_file = open(trials_filename, "r").readlines()
    c_miss = c_miss
    c_fa = c_fa
    p_target = p_target
@@ -136,23 +154,22 @@
            else:
                labels.append(0)
        else:
            raise Exception("Missing entry for " + utt1 + " and " + utt2
                            + " " + scores_filename)
            raise Exception("Missing entry for " + utt1 + " and " + utt2 + " " + scores_filename)
    fnrs, fprs, thresholds = ComputeErrorRates(scores, labels)
    mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target,
                                      c_miss, c_fa)
    mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa)
    return mindcf, threshold
def main():
    args = GetArgs()
    mindcf, threshold = compute_min_dcf(
        args.scores_filename, args.trials_filename,
        args.c_miss, args.c_fa, args.p_target
        args.scores_filename, args.trials_filename, args.c_miss, args.c_fa, args.p_target
    )
    sys.stdout.write("minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
                     "c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa))
    sys.stdout.write(
        "minDCF is {0:.4f} at threshold {1:.4f} (p-target={2}, c-miss={3}, "
        "c-fa={4})\n".format(mindcf, threshold, args.p_target, args.c_miss, args.c_fa)
    )
if __name__ == "__main__":