kongdeqiang
5 天以前 28ccfbfc51068a663a80764e14074df5edf2b5ba
funasr/metrics/common.py
@@ -36,9 +36,9 @@
        hyp_length = i - m
        hyps_same_length = [x for x in ended_hyps if len(x["yseq"]) == hyp_length]
        if len(hyps_same_length) > 0:
            best_hyp_same_length = sorted(
                hyps_same_length, key=lambda x: x["score"], reverse=True
            )[0]
            best_hyp_same_length = sorted(hyps_same_length, key=lambda x: x["score"], reverse=True)[
                0
            ]
            if best_hyp_same_length["score"] - best_hyp["score"] < D_end:
                count += 1
@@ -63,9 +63,7 @@
            trans_json = json.load(f)["utts"]
    if lsm_type == "unigram":
        assert transcript is not None, (
            "transcript is required for %s label smoothing" % lsm_type
        )
        assert transcript is not None, "transcript is required for %s label smoothing" % lsm_type
        labelcount = np.zeros(odim)
        for k, v in trans_json.items():
            ids = np.array([int(n) for n in v["output"][0]["tokenid"].split()])
@@ -108,9 +106,7 @@
    :return:
    """
    def __init__(
        self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False
    ):
    def __init__(self, char_list, sym_space, sym_blank, report_cer=False, report_wer=False):
        """Construct an ErrorCalculator object."""
        super(ErrorCalculator, self).__init__()