liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
@@ -9,6 +9,7 @@
from tqdm import tqdm
import os
import pdb
remove_tag = False
spacelist = [" ", "\t", "\r", "\n"]
puncts = [
@@ -51,9 +52,9 @@
    def get_wer(self):
        assert self.ref_words != 0
        errors = (
                self.errors[Code.substitution]
                + self.errors[Code.insertion]
                + self.errors[Code.deletion]
            self.errors[Code.substitution]
            + self.errors[Code.insertion]
            + self.errors[Code.deletion]
        )
        return 100.0 * errors / self.ref_words
@@ -299,30 +300,30 @@
    for i in reversed(range(len(unicode_names))):
        if unicode_names[i].startswith("DIGIT"):  # 1
            unicode_names[i] = "Number"  # 'DIGIT'
        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
            i
        ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[i].startswith(
            "CJK COMPATIBILITY IDEOGRAPH"
        ):
            # 明 / 郎
            unicode_names[i] = "Mandarin"  # 'CJK IDEOGRAPH'
        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
            i
        ].startswith("LATIN SMALL LETTER"):
        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[i].startswith(
            "LATIN SMALL LETTER"
        ):
            # A / a
            unicode_names[i] = "English"  # 'LATIN LETTER'
        elif unicode_names[i].startswith("HIRAGANA LETTER"):  # は こ め
            unicode_names[i] = "Japanese"  # 'GANA LETTER'
        elif (
                unicode_names[i].startswith("AMPERSAND")
                or unicode_names[i].startswith("APOSTROPHE")
                or unicode_names[i].startswith("COMMERCIAL AT")
                or unicode_names[i].startswith("DEGREE CELSIUS")
                or unicode_names[i].startswith("EQUALS SIGN")
                or unicode_names[i].startswith("FULL STOP")
                or unicode_names[i].startswith("HYPHEN-MINUS")
                or unicode_names[i].startswith("LOW LINE")
                or unicode_names[i].startswith("NUMBER SIGN")
                or unicode_names[i].startswith("PLUS SIGN")
                or unicode_names[i].startswith("SEMICOLON")
            unicode_names[i].startswith("AMPERSAND")
            or unicode_names[i].startswith("APOSTROPHE")
            or unicode_names[i].startswith("COMMERCIAL AT")
            or unicode_names[i].startswith("DEGREE CELSIUS")
            or unicode_names[i].startswith("EQUALS SIGN")
            or unicode_names[i].startswith("FULL STOP")
            or unicode_names[i].startswith("HYPHEN-MINUS")
            or unicode_names[i].startswith("LOW LINE")
            or unicode_names[i].startswith("NUMBER SIGN")
            or unicode_names[i].startswith("PLUS SIGN")
            or unicode_names[i].startswith("SEMICOLON")
        ):
            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
            del unicode_names[i]
@@ -411,11 +412,13 @@
                if len(array) == 0:
                    continue
                fid = array[0]
                rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
                rec_sets[rec_names[i]][fid] = normalize(
                    array[1:], ignore_words, case_sensitive, split
                )
        calculators_dict[rec_names[i]] = Calculator()
        ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
        hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
        hotwords_related_dict[rec_names[i]] = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
        # tp: the hotword is in the label and also in rec
        # tn: the hotword is not in the label and not in rec either
        # fp: the hotword is not in the label but is in rec
@@ -431,21 +434,22 @@
        _file_total_len = int(pipe.read().strip())
    # compute error rate on the interaction of reference file and hyp file
    for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
    for line in tqdm(open(ref_file, "r", encoding="utf-8"), total=_file_total_len):
        if tochar:
            array = characterize(line)
        else:
            array = line.rstrip('\n').split()
        if len(array) == 0: continue
            array = line.rstrip("\n").split()
        if len(array) == 0:
            continue
        fid = array[0]
        lab = normalize(array[1:], ignore_words, case_sensitive, split)
        if verbose:
            print('\nutt: %s' % fid)
            print("\nutt: %s" % fid)
        ocr_text = ref_ocr_dict[fid]
        ocr_set = set(ocr_text)
        print('ocr: {}'.format(" ".join(ocr_text)))
        print("ocr: {}".format(" ".join(ocr_text)))
        list_match = []  # 指label里面在ocr里面的内容
        list_not_mathch = []
        tmp_error = 0
@@ -458,7 +462,7 @@
            else:
                tmp_match += 1
                list_match.append(lab[index])
        print('label in ocr: {}'.format(" ".join(list_match)))
        print("label in ocr: {}".format(" ".join(list_match)))
        # for each reco file
        base_wrong_ocr_wer = None
@@ -482,33 +486,44 @@
            result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
            if verbose:
                if result['all'] != 0:
                    wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
                if result["all"] != 0:
                    wer = (
                        float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"]
                    )
                else:
                    wer = 0.0
            print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
            print('N=%d C=%d S=%d D=%d I=%d' %
                  (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
            print("WER(%s): %4.2f %%" % (rec_name, wer), end=" ")
            print(
                "N=%d C=%d S=%d D=%d I=%d"
                % (result["all"], result["cor"], result["sub"], result["del"], result["ins"])
            )
            # print(result['rec'])
            wrong_rec_but_in_ocr = []
            for idx in range(len(result['lab'])):
                if result['lab'][idx] != "":
                    if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
                        if result['lab'][idx] in list_match:
                            wrong_rec_but_in_ocr.append(result['lab'][idx])
            for idx in range(len(result["lab"])):
                if result["lab"][idx] != "":
                    if result["lab"][idx] != result["rec"][idx].replace("<BIAS>", ""):
                        if result["lab"][idx] in list_match:
                            wrong_rec_but_in_ocr.append(result["lab"][idx])
                            wrong_rec_but_in_ocr_dict[rec_name] += 1
            print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
            print("wrong_rec_but_in_ocr: {}".format(" ".join(wrong_rec_but_in_ocr)))
            if rec_name == "base":
                base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
            if "ocr" in rec_name or "hot" in rec_name:
                ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
                if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
                    print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
                    print(
                        "{} {} helps, {} -> {}".format(
                            fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer
                        )
                    )
                elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
                    print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
                    print(
                        "{} {} hurts, {} -> {}".format(
                            fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer
                        )
                    )
            # recall = 0
            # false_alarm = 0
@@ -537,11 +552,11 @@
                #     if badhotword == word:
                #         count += 1
                if count == 0:
                    hotwords_related_dict[rec_name]['tn'] += 1
                    hotwords_related_dict[rec_name]["tn"] += 1
                    _tn += 1
                    # fp: 0
                else:
                    hotwords_related_dict[rec_name]['fp'] += count
                    hotwords_related_dict[rec_name]["fp"] += count
                    _fp += count
                    # tn: 0
                # if badhotword in _rec_list:
@@ -553,23 +568,30 @@
                rec_count = len([word for word in _rec_list if hotword == word])
                # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
                if rec_count == true_count:
                    hotwords_related_dict[rec_name]['tp'] += true_count
                    hotwords_related_dict[rec_name]["tp"] += true_count
                    _tp += true_count
                elif rec_count > true_count:
                    hotwords_related_dict[rec_name]['tp'] += true_count
                    hotwords_related_dict[rec_name]["tp"] += true_count
                    # fp: 不在label里,但是在rec里
                    hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
                    hotwords_related_dict[rec_name]["fp"] += rec_count - true_count
                    _tp += true_count
                    _fp += rec_count - true_count
                else:
                    hotwords_related_dict[rec_name]['tp'] += rec_count
                    hotwords_related_dict[rec_name]["tp"] += rec_count
                    # fn: 热词在label里,但是不在rec里
                    hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
                    hotwords_related_dict[rec_name]["fn"] += true_count - rec_count
                    _tp += rec_count
                    _fn += true_count - rec_count
            print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
                _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
            ))
            print(
                "hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
                    _tp,
                    _tn,
                    _fp,
                    _fn,
                    sum([_tp, _tn, _fp, _fn]),
                    _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0,
                )
            )
            # if hotword in _rec_list:
            #     hotwords_related_dict[rec_name]['tp'] += 1
@@ -612,77 +634,89 @@
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
            space = {}
            space['lab'] = []
            space['rec'] = []
            for idx in range(len(result['lab'])):
                len_lab = width(result['lab'][idx])
                len_rec = width(result['rec'][idx])
            space["lab"] = []
            space["rec"] = []
            for idx in range(len(result["lab"])):
                len_lab = width(result["lab"][idx])
                len_rec = width(result["rec"][idx])
                length = max(len_lab, len_rec)
                space['lab'].append(length - len_lab)
                space['rec'].append(length - len_rec)
            upper_lab = len(result['lab'])
            upper_rec = len(result['rec'])
                space["lab"].append(length - len_lab)
                space["rec"].append(length - len_rec)
            upper_lab = len(result["lab"])
            upper_rec = len(result["rec"])
            lab1, rec1 = 0, 0
            while lab1 < upper_lab or rec1 < upper_rec:
                if verbose > 1:
                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
                    print("lab(%s):" % fid.encode("utf-8"), end=" ")
                else:
                    print('lab:', end=' ')
                    print("lab:", end=" ")
                lab2 = min(upper_lab, lab1 + max_words_per_line)
                for idx in range(lab1, lab2):
                    token = result['lab'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['lab'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                    token = result["lab"][idx]
                    print("{token}".format(token=token), end="")
                    for n in range(space["lab"][idx]):
                        print(padding_symbol, end="")
                    print(" ", end="")
                print()
                if verbose > 1:
                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
                    print("rec(%s):" % fid.encode("utf-8"), end=" ")
                else:
                    print('rec:', end=' ')
                    print("rec:", end=" ")
                rec2 = min(upper_rec, rec1 + max_words_per_line)
                for idx in range(rec1, rec2):
                    token = result['rec'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['rec'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                    token = result["rec"][idx]
                    print("{token}".format(token=token), end="")
                    for n in range(space["rec"][idx]):
                        print(padding_symbol, end="")
                    print(" ", end="")
                print()
                # print('\n', end='\n')
                lab1 = lab2
                rec1 = rec2
        print('\n', end='\n')
        print("\n", end="\n")
        # break
    if verbose:
        print('===========================================================================')
        print("===========================================================================")
        print()
    print(wrong_rec_but_in_ocr_dict)
    for rec_name in rec_names:
        result = calculators_dict[rec_name].overall()
        if result['all'] != 0:
            wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
        if result["all"] != 0:
            wer = float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"]
        else:
            wer = 0.0
        print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
        print('N=%d C=%d S=%d D=%d I=%d' %
              (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
        print("{} Overall -> {:4.2f} %".format(rec_name, wer), end=" ")
        print(
            "N=%d C=%d S=%d D=%d I=%d"
            % (result["all"], result["cor"], result["sub"], result["del"], result["ins"])
        )
        print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
        print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
        print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
        print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
            hotwords_related_dict[rec_name]['tp'],
            hotwords_related_dict[rec_name]['tn'],
            hotwords_related_dict[rec_name]['fp'],
            hotwords_related_dict[rec_name]['fn'],
            sum([v for k, v in hotwords_related_dict[rec_name].items()]),
            hotwords_related_dict[rec_name]['tp'] / (
                    hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
            ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
        ))
        print(
            "hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
                hotwords_related_dict[rec_name]["tp"],
                hotwords_related_dict[rec_name]["tn"],
                hotwords_related_dict[rec_name]["fp"],
                hotwords_related_dict[rec_name]["fn"],
                sum([v for k, v in hotwords_related_dict[rec_name].items()]),
                (
                    hotwords_related_dict[rec_name]["tp"]
                    / (
                        hotwords_related_dict[rec_name]["tp"]
                        + hotwords_related_dict[rec_name]["fn"]
                    )
                    * 100
                    if hotwords_related_dict[rec_name]["tp"] + hotwords_related_dict[rec_name]["fn"]
                    != 0
                    else 0
                ),
            )
        )
        # tp: the hotword is in the label and also in rec
        # tn: the hotword is not in the label and not in rec either
@@ -695,8 +729,7 @@
if __name__ == "__main__":
    # Runs only when the file is executed as a script, not when imported.
    # Obtain the run arguments from get_args(), echo them to stdout, then
    # pass them to main().
    args = get_args()
    print(args)
    main(args)