| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- coding: utf-8 -*- |
| | | |
| | | |
| | | from enum import Enum |
| | | import re, sys, unicodedata |
| | | import codecs |
| | | import argparse |
| | | from tqdm import tqdm |
| | | import os |
| | | import pdb |
| | | remove_tag = False |
| | | spacelist = [" ", "\t", "\r", "\n"] |
| | | puncts = [ |
| | | "!", |
| | | ",", |
| | | "?", |
| | | "、", |
| | | "。", |
| | | "!", |
| | | ",", |
| | | ";", |
| | | "?", |
| | | ":", |
| | | "「", |
| | | "」", |
| | | "︰", |
| | | "『", |
| | | "』", |
| | | "《", |
| | | "》", |
| | | ] |
| | | |
| | | |
| | | class Code(Enum): |
| | | match = 1 |
| | | substitution = 2 |
| | | insertion = 3 |
| | | deletion = 4 |
| | | |
| | | |
| | | class WordError(object): |
| | | def __init__(self): |
| | | self.errors = { |
| | | Code.substitution: 0, |
| | | Code.insertion: 0, |
| | | Code.deletion: 0, |
| | | } |
| | | self.ref_words = 0 |
| | | |
| | | def get_wer(self): |
| | | assert self.ref_words != 0 |
| | | errors = ( |
| | | self.errors[Code.substitution] |
| | | + self.errors[Code.insertion] |
| | | + self.errors[Code.deletion] |
| | | ) |
| | | return 100.0 * errors / self.ref_words |
| | | |
| | | def get_result_string(self): |
| | | return ( |
| | | f"error_rate={self.get_wer():.4f}, " |
| | | f"ref_words={self.ref_words}, " |
| | | f"subs={self.errors[Code.substitution]}, " |
| | | f"ins={self.errors[Code.insertion]}, " |
| | | f"dels={self.errors[Code.deletion]}" |
| | | ) |
| | | |
| | | |
| | | def characterize(string): |
| | | res = [] |
| | | i = 0 |
| | | while i < len(string): |
| | | char = string[i] |
| | | if char in puncts: |
| | | i += 1 |
| | | continue |
| | | cat1 = unicodedata.category(char) |
| | | # https://unicodebook.readthedocs.io/unicode.html#unicode-categories |
| | | if cat1 == "Zs" or cat1 == "Cn" or char in spacelist: # space or not assigned |
| | | i += 1 |
| | | continue |
| | | if cat1 == "Lo": # letter-other |
| | | res.append(char) |
| | | i += 1 |
| | | else: |
| | | # some input looks like: <unk><noise>, we want to separate it to two words. |
| | | sep = " " |
| | | if char == "<": |
| | | sep = ">" |
| | | j = i + 1 |
| | | while j < len(string): |
| | | c = string[j] |
| | | if ord(c) >= 128 or (c in spacelist) or (c == sep): |
| | | break |
| | | j += 1 |
| | | if j < len(string) and string[j] == ">": |
| | | j += 1 |
| | | res.append(string[i:j]) |
| | | i = j |
| | | return res |
| | | |
| | | |
| | | def stripoff_tags(x): |
| | | if not x: |
| | | return "" |
| | | chars = [] |
| | | i = 0 |
| | | T = len(x) |
| | | while i < T: |
| | | if x[i] == "<": |
| | | while i < T and x[i] != ">": |
| | | i += 1 |
| | | i += 1 |
| | | else: |
| | | chars.append(x[i]) |
| | | i += 1 |
| | | return "".join(chars) |
| | | |
| | | |
| | | def normalize(sentence, ignore_words, cs, split=None): |
| | | """sentence, ignore_words are both in unicode""" |
| | | new_sentence = [] |
| | | for token in sentence: |
| | | x = token |
| | | if not cs: |
| | | x = x.upper() |
| | | if x in ignore_words: |
| | | continue |
| | | if remove_tag: |
| | | x = stripoff_tags(x) |
| | | if not x: |
| | | continue |
| | | if split and x in split: |
| | | new_sentence += split[x] |
| | | else: |
| | | new_sentence.append(x) |
| | | return new_sentence |
| | | |
| | | |
| | | class Calculator: |
| | | def __init__(self): |
| | | self.data = {} |
| | | self.space = [] |
| | | self.cost = {} |
| | | self.cost["cor"] = 0 |
| | | self.cost["sub"] = 1 |
| | | self.cost["del"] = 1 |
| | | self.cost["ins"] = 1 |
| | | |
| | | def calculate(self, lab, rec): |
| | | # Initialization |
| | | lab.insert(0, "") |
| | | rec.insert(0, "") |
| | | while len(self.space) < len(lab): |
| | | self.space.append([]) |
| | | for row in self.space: |
| | | for element in row: |
| | | element["dist"] = 0 |
| | | element["error"] = "non" |
| | | while len(row) < len(rec): |
| | | row.append({"dist": 0, "error": "non"}) |
| | | for i in range(len(lab)): |
| | | self.space[i][0]["dist"] = i |
| | | self.space[i][0]["error"] = "del" |
| | | for j in range(len(rec)): |
| | | self.space[0][j]["dist"] = j |
| | | self.space[0][j]["error"] = "ins" |
| | | self.space[0][0]["error"] = "non" |
| | | for token in lab: |
| | | if token not in self.data and len(token) > 0: |
| | | self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} |
| | | for token in rec: |
| | | if token not in self.data and len(token) > 0: |
| | | self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} |
| | | # Computing edit distance |
| | | for i, lab_token in enumerate(lab): |
| | | for j, rec_token in enumerate(rec): |
| | | if i == 0 or j == 0: |
| | | continue |
| | | min_dist = sys.maxsize |
| | | min_error = "none" |
| | | dist = self.space[i - 1][j]["dist"] + self.cost["del"] |
| | | error = "del" |
| | | if dist < min_dist: |
| | | min_dist = dist |
| | | min_error = error |
| | | dist = self.space[i][j - 1]["dist"] + self.cost["ins"] |
| | | error = "ins" |
| | | if dist < min_dist: |
| | | min_dist = dist |
| | | min_error = error |
| | | if lab_token == rec_token.replace("<BIAS>", ""): |
| | | dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"] |
| | | error = "cor" |
| | | else: |
| | | dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"] |
| | | error = "sub" |
| | | if dist < min_dist: |
| | | min_dist = dist |
| | | min_error = error |
| | | self.space[i][j]["dist"] = min_dist |
| | | self.space[i][j]["error"] = min_error |
| | | # Tracing back |
| | | result = { |
| | | "lab": [], |
| | | "rec": [], |
| | | "code": [], |
| | | "all": 0, |
| | | "cor": 0, |
| | | "sub": 0, |
| | | "ins": 0, |
| | | "del": 0, |
| | | } |
| | | i = len(lab) - 1 |
| | | j = len(rec) - 1 |
| | | while True: |
| | | if self.space[i][j]["error"] == "cor": # correct |
| | | if len(lab[i]) > 0: |
| | | self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 |
| | | self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1 |
| | | result["all"] = result["all"] + 1 |
| | | result["cor"] = result["cor"] + 1 |
| | | result["lab"].insert(0, lab[i]) |
| | | result["rec"].insert(0, rec[j]) |
| | | result["code"].insert(0, Code.match) |
| | | i = i - 1 |
| | | j = j - 1 |
| | | elif self.space[i][j]["error"] == "sub": # substitution |
| | | if len(lab[i]) > 0: |
| | | self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 |
| | | self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1 |
| | | result["all"] = result["all"] + 1 |
| | | result["sub"] = result["sub"] + 1 |
| | | result["lab"].insert(0, lab[i]) |
| | | result["rec"].insert(0, rec[j]) |
| | | result["code"].insert(0, Code.substitution) |
| | | i = i - 1 |
| | | j = j - 1 |
| | | elif self.space[i][j]["error"] == "del": # deletion |
| | | if len(lab[i]) > 0: |
| | | self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 |
| | | self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1 |
| | | result["all"] = result["all"] + 1 |
| | | result["del"] = result["del"] + 1 |
| | | result["lab"].insert(0, lab[i]) |
| | | result["rec"].insert(0, "") |
| | | result["code"].insert(0, Code.deletion) |
| | | i = i - 1 |
| | | elif self.space[i][j]["error"] == "ins": # insertion |
| | | if len(rec[j]) > 0: |
| | | self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1 |
| | | result["ins"] = result["ins"] + 1 |
| | | result["lab"].insert(0, "") |
| | | result["rec"].insert(0, rec[j]) |
| | | result["code"].insert(0, Code.insertion) |
| | | j = j - 1 |
| | | elif self.space[i][j]["error"] == "non": # starting point |
| | | break |
| | | else: # shouldn't reach here |
| | | print( |
| | | "this should not happen , i = {i} , j = {j} , error = {error}".format( |
| | | i=i, j=j, error=self.space[i][j]["error"] |
| | | ) |
| | | ) |
| | | return result |
| | | |
| | | def overall(self): |
| | | result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} |
| | | for token in self.data: |
| | | result["all"] = result["all"] + self.data[token]["all"] |
| | | result["cor"] = result["cor"] + self.data[token]["cor"] |
| | | result["sub"] = result["sub"] + self.data[token]["sub"] |
| | | result["ins"] = result["ins"] + self.data[token]["ins"] |
| | | result["del"] = result["del"] + self.data[token]["del"] |
| | | return result |
| | | |
| | | def cluster(self, data): |
| | | result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} |
| | | for token in data: |
| | | if token in self.data: |
| | | result["all"] = result["all"] + self.data[token]["all"] |
| | | result["cor"] = result["cor"] + self.data[token]["cor"] |
| | | result["sub"] = result["sub"] + self.data[token]["sub"] |
| | | result["ins"] = result["ins"] + self.data[token]["ins"] |
| | | result["del"] = result["del"] + self.data[token]["del"] |
| | | return result |
| | | |
| | | def keys(self): |
| | | return list(self.data.keys()) |
| | | |
| | | |
| | | def width(string): |
| | | return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) |
| | | |
| | | |
| | | def default_cluster(word): |
| | | unicode_names = [unicodedata.name(char) for char in word] |
| | | for i in reversed(range(len(unicode_names))): |
| | | if unicode_names[i].startswith("DIGIT"): # 1 |
| | | unicode_names[i] = "Number" # 'DIGIT' |
| | | elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[ |
| | | i |
| | | ].startswith("CJK COMPATIBILITY IDEOGRAPH"): |
| | | # 明 / 郎 |
| | | unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH' |
| | | elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[ |
| | | i |
| | | ].startswith("LATIN SMALL LETTER"): |
| | | # A / a |
| | | unicode_names[i] = "English" # 'LATIN LETTER' |
| | | elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め |
| | | unicode_names[i] = "Japanese" # 'GANA LETTER' |
| | | elif ( |
| | | unicode_names[i].startswith("AMPERSAND") |
| | | or unicode_names[i].startswith("APOSTROPHE") |
| | | or unicode_names[i].startswith("COMMERCIAL AT") |
| | | or unicode_names[i].startswith("DEGREE CELSIUS") |
| | | or unicode_names[i].startswith("EQUALS SIGN") |
| | | or unicode_names[i].startswith("FULL STOP") |
| | | or unicode_names[i].startswith("HYPHEN-MINUS") |
| | | or unicode_names[i].startswith("LOW LINE") |
| | | or unicode_names[i].startswith("NUMBER SIGN") |
| | | or unicode_names[i].startswith("PLUS SIGN") |
| | | or unicode_names[i].startswith("SEMICOLON") |
| | | ): |
| | | # & / ' / @ / ℃ / = / . / - / _ / # / + / ; |
| | | del unicode_names[i] |
| | | else: |
| | | return "Other" |
| | | if len(unicode_names) == 0: |
| | | return "Other" |
| | | if len(unicode_names) == 1: |
| | | return unicode_names[0] |
| | | for i in range(len(unicode_names) - 1): |
| | | if unicode_names[i] != unicode_names[i + 1]: |
| | | return "Other" |
| | | return unicode_names[0] |
| | | |
| | | |
| | | def get_args(): |
| | | parser = argparse.ArgumentParser(description="wer cal") |
| | | parser.add_argument("--ref", type=str, help="Text input path") |
| | | parser.add_argument("--ref_ocr", type=str, help="Text input path") |
| | | parser.add_argument("--rec_name", type=str, action="append", default=[]) |
| | | parser.add_argument("--rec_file", type=str, action="append", default=[]) |
| | | parser.add_argument("--verbose", type=int, default=1, help="show") |
| | | parser.add_argument("--char", type=bool, default=True, help="show") |
| | | args = parser.parse_args() |
| | | return args |
| | | |
| | | |
| | | def main(args): |
| | | cluster_file = "" |
| | | ignore_words = set() |
| | | tochar = args.char |
| | | verbose = args.verbose |
| | | padding_symbol = " " |
| | | case_sensitive = False |
| | | max_words_per_line = sys.maxsize |
| | | split = None |
| | | |
| | | if not case_sensitive: |
| | | ig = set([w.upper() for w in ignore_words]) |
| | | ignore_words = ig |
| | | |
| | | default_clusters = {} |
| | | default_words = {} |
| | | ref_file = args.ref |
| | | ref_ocr = args.ref_ocr |
| | | rec_files = args.rec_file |
| | | rec_names = args.rec_name |
| | | assert len(rec_files) == len(rec_names) |
| | | |
| | | # load ocr |
| | | ref_ocr_dict = {} |
| | | with codecs.open(ref_ocr, "r", "utf-8") as fh: |
| | | for line in fh: |
| | | if "$" in line: |
| | | line = line.replace("$", " ") |
| | | if tochar: |
| | | array = characterize(line) |
| | | else: |
| | | array = line.strip().split() |
| | | if len(array) == 0: |
| | | continue |
| | | fid = array[0] |
| | | ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split) |
| | | |
| | | if split and not case_sensitive: |
| | | newsplit = dict() |
| | | for w in split: |
| | | words = split[w] |
| | | for i in range(len(words)): |
| | | words[i] = words[i].upper() |
| | | newsplit[w.upper()] = words |
| | | split = newsplit |
| | | |
| | | rec_sets = {} |
| | | calculators_dict = dict() |
| | | ub_wer_dict = dict() |
| | | hotwords_related_dict = dict() # 记录recall相关的内容 |
| | | for i, hyp_file in enumerate(rec_files): |
| | | rec_sets[rec_names[i]] = dict() |
| | | with codecs.open(hyp_file, "r", "utf-8") as fh: |
| | | for line in fh: |
| | | if tochar: |
| | | array = characterize(line) |
| | | else: |
| | | array = line.strip().split() |
| | | if len(array) == 0: |
| | | continue |
| | | fid = array[0] |
| | | rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split) |
| | | |
| | | calculators_dict[rec_names[i]] = Calculator() |
| | | ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()} |
| | | hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0} |
| | | # tp: 热词在label里,同时在rec里 |
| | | # tn: 热词不在label里,同时不在rec里 |
| | | # fp: 热词不在label里,但是在rec里 |
| | | # fn: 热词在label里,但是不在rec里 |
| | | |
| | | # record wrong label but in ocr |
| | | wrong_rec_but_in_ocr_dict = {} |
| | | for rec_name in rec_names: |
| | | wrong_rec_but_in_ocr_dict[rec_name] = 0 |
| | | |
| | | _file_total_len = 0 |
| | | with os.popen("cat {} | wc -l".format(ref_file)) as pipe: |
| | | _file_total_len = int(pipe.read().strip()) |
| | | |
| | | # compute error rate on the interaction of reference file and hyp file |
| | | for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len): |
| | | if tochar: |
| | | array = characterize(line) |
| | | else: |
| | | array = line.rstrip('\n').split() |
| | | if len(array) == 0: continue |
| | | fid = array[0] |
| | | lab = normalize(array[1:], ignore_words, case_sensitive, split) |
| | | |
| | | if verbose: |
| | | print('\nutt: %s' % fid) |
| | | |
| | | ocr_text = ref_ocr_dict[fid] |
| | | ocr_set = set(ocr_text) |
| | | print('ocr: {}'.format(" ".join(ocr_text))) |
| | | list_match = [] # 指label里面在ocr里面的内容 |
| | | list_not_mathch = [] |
| | | tmp_error = 0 |
| | | tmp_match = 0 |
| | | for index in range(len(lab)): |
| | | # text_list.append(uttlist[index+1]) |
| | | if lab[index] not in ocr_set: |
| | | tmp_error += 1 |
| | | list_not_mathch.append(lab[index]) |
| | | else: |
| | | tmp_match += 1 |
| | | list_match.append(lab[index]) |
| | | print('label in ocr: {}'.format(" ".join(list_match))) |
| | | |
| | | # for each reco file |
| | | base_wrong_ocr_wer = None |
| | | ocr_wrong_ocr_wer = None |
| | | |
| | | for rec_name in rec_names: |
| | | rec_set = rec_sets[rec_name] |
| | | if fid not in rec_set: |
| | | continue |
| | | rec = rec_set[fid] |
| | | |
| | | # print(rec) |
| | | for word in rec + lab: |
| | | if word not in default_words: |
| | | default_cluster_name = default_cluster(word) |
| | | if default_cluster_name not in default_clusters: |
| | | default_clusters[default_cluster_name] = {} |
| | | if word not in default_clusters[default_cluster_name]: |
| | | default_clusters[default_cluster_name][word] = 1 |
| | | default_words[word] = default_cluster_name |
| | | |
| | | result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy()) |
| | | if verbose: |
| | | if result['all'] != 0: |
| | | wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] |
| | | else: |
| | | wer = 0.0 |
| | | print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ') |
| | | print('N=%d C=%d S=%d D=%d I=%d' % |
| | | (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) |
| | | |
| | | |
| | | # print(result['rec']) |
| | | wrong_rec_but_in_ocr = [] |
| | | for idx in range(len(result['lab'])): |
| | | if result['lab'][idx] != "": |
| | | if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""): |
| | | if result['lab'][idx] in list_match: |
| | | wrong_rec_but_in_ocr.append(result['lab'][idx]) |
| | | wrong_rec_but_in_ocr_dict[rec_name] += 1 |
| | | print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr))) |
| | | |
| | | if rec_name == "base": |
| | | base_wrong_ocr_wer = len(wrong_rec_but_in_ocr) |
| | | if "ocr" in rec_name or "hot" in rec_name: |
| | | ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr) |
| | | if ocr_wrong_ocr_wer < base_wrong_ocr_wer: |
| | | print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer)) |
| | | elif ocr_wrong_ocr_wer > base_wrong_ocr_wer: |
| | | print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer)) |
| | | |
| | | # recall = 0 |
| | | # false_alarm = 0 |
| | | # for idx in range(len(result['lab'])): |
| | | # if "<BIAS>" in result['rec'][idx]: |
| | | # if result['rec'][idx].replace("<BIAS>", "") in list_match: |
| | | # recall += 1 |
| | | # else: |
| | | # false_alarm += 1 |
| | | # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format( |
| | | # recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0 |
| | | # )) |
| | | # tp: 热词在label里,同时在rec里 |
| | | # tn: 热词不在label里,同时不在rec里 |
| | | # fp: 热词不在label里,但是在rec里 |
| | | # fn: 热词在label里,但是不在rec里 |
| | | _rec_list = [word.replace("<BIAS>", "") for word in rec] |
| | | _label_list = [word for word in lab] |
| | | _tp = _tn = _fp = _fn = 0 |
| | | hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list] |
| | | hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list] |
| | | for badhotword in hot_bad_list: |
| | | count = len([word for word in _rec_list if word == badhotword]) |
| | | # print(f"bad {badhotword} count: {count}") |
| | | # for word in _rec_list: |
| | | # if badhotword == word: |
| | | # count += 1 |
| | | if count == 0: |
| | | hotwords_related_dict[rec_name]['tn'] += 1 |
| | | _tn += 1 |
| | | # fp: 0 |
| | | else: |
| | | hotwords_related_dict[rec_name]['fp'] += count |
| | | _fp += count |
| | | # tn: 0 |
| | | # if badhotword in _rec_list: |
| | | # hotwords_related_dict[rec_name]['fp'] += 1 |
| | | # else: |
| | | # hotwords_related_dict[rec_name]['tn'] += 1 |
| | | for hotword in hot_true_list: |
| | | true_count = len([word for word in _label_list if hotword == word]) |
| | | rec_count = len([word for word in _rec_list if hotword == word]) |
| | | # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}") |
| | | if rec_count == true_count: |
| | | hotwords_related_dict[rec_name]['tp'] += true_count |
| | | _tp += true_count |
| | | elif rec_count > true_count: |
| | | hotwords_related_dict[rec_name]['tp'] += true_count |
| | | # fp: 不在label里,但是在rec里 |
| | | hotwords_related_dict[rec_name]['fp'] += rec_count - true_count |
| | | _tp += true_count |
| | | _fp += rec_count - true_count |
| | | else: |
| | | hotwords_related_dict[rec_name]['tp'] += rec_count |
| | | # fn: 热词在label里,但是不在rec里 |
| | | hotwords_related_dict[rec_name]['fn'] += true_count - rec_count |
| | | _tp += rec_count |
| | | _fn += true_count - rec_count |
| | | print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format( |
| | | _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0 |
| | | )) |
| | | |
| | | # if hotword in _rec_list: |
| | | # hotwords_related_dict[rec_name]['tp'] += 1 |
| | | # else: |
| | | # hotwords_related_dict[rec_name]['fn'] += 1 |
| | | # 计算uwer, bwer, wer |
| | | for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]): |
| | | if code == Code.match: |
| | | ub_wer_dict[rec_name]["wer"].ref_words += 1 |
| | | if lab_word in hot_true_list: |
| | | # tmp_ref.append(ref_tokens[ref_idx]) |
| | | ub_wer_dict[rec_name]["b_wer"].ref_words += 1 |
| | | else: |
| | | ub_wer_dict[rec_name]["u_wer"].ref_words += 1 |
| | | elif code == Code.substitution: |
| | | ub_wer_dict[rec_name]["wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1 |
| | | if lab_word in hot_true_list: |
| | | # tmp_ref.append(ref_tokens[ref_idx]) |
| | | ub_wer_dict[rec_name]["b_wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1 |
| | | else: |
| | | ub_wer_dict[rec_name]["u_wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1 |
| | | elif code == Code.deletion: |
| | | ub_wer_dict[rec_name]["wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1 |
| | | if lab_word in hot_true_list: |
| | | # tmp_ref.append(ref_tokens[ref_idx]) |
| | | ub_wer_dict[rec_name]["b_wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1 |
| | | else: |
| | | ub_wer_dict[rec_name]["u_wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1 |
| | | elif code == Code.insertion: |
| | | ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1 |
| | | if rec_word in hot_true_list: |
| | | ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1 |
| | | else: |
| | | ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1 |
| | | |
| | | space = {} |
| | | space['lab'] = [] |
| | | space['rec'] = [] |
| | | for idx in range(len(result['lab'])): |
| | | len_lab = width(result['lab'][idx]) |
| | | len_rec = width(result['rec'][idx]) |
| | | length = max(len_lab, len_rec) |
| | | space['lab'].append(length - len_lab) |
| | | space['rec'].append(length - len_rec) |
| | | upper_lab = len(result['lab']) |
| | | upper_rec = len(result['rec']) |
| | | lab1, rec1 = 0, 0 |
| | | while lab1 < upper_lab or rec1 < upper_rec: |
| | | if verbose > 1: |
| | | print('lab(%s):' % fid.encode('utf-8'), end=' ') |
| | | else: |
| | | print('lab:', end=' ') |
| | | lab2 = min(upper_lab, lab1 + max_words_per_line) |
| | | for idx in range(lab1, lab2): |
| | | token = result['lab'][idx] |
| | | print('{token}'.format(token=token), end='') |
| | | for n in range(space['lab'][idx]): |
| | | print(padding_symbol, end='') |
| | | print(' ', end='') |
| | | print() |
| | | if verbose > 1: |
| | | print('rec(%s):' % fid.encode('utf-8'), end=' ') |
| | | else: |
| | | print('rec:', end=' ') |
| | | |
| | | rec2 = min(upper_rec, rec1 + max_words_per_line) |
| | | for idx in range(rec1, rec2): |
| | | token = result['rec'][idx] |
| | | print('{token}'.format(token=token), end='') |
| | | for n in range(space['rec'][idx]): |
| | | print(padding_symbol, end='') |
| | | print(' ', end='') |
| | | print() |
| | | # print('\n', end='\n') |
| | | lab1 = lab2 |
| | | rec1 = rec2 |
| | | print('\n', end='\n') |
| | | # break |
| | | if verbose: |
| | | print('===========================================================================') |
| | | print() |
| | | |
| | | print(wrong_rec_but_in_ocr_dict) |
| | | for rec_name in rec_names: |
| | | result = calculators_dict[rec_name].overall() |
| | | |
| | | if result['all'] != 0: |
| | | wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] |
| | | else: |
| | | wer = 0.0 |
| | | print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ') |
| | | print('N=%d C=%d S=%d D=%d I=%d' % |
| | | (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) |
| | | print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}") |
| | | print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}") |
| | | print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}") |
| | | |
| | | print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format( |
| | | hotwords_related_dict[rec_name]['tp'], |
| | | hotwords_related_dict[rec_name]['tn'], |
| | | hotwords_related_dict[rec_name]['fp'], |
| | | hotwords_related_dict[rec_name]['fn'], |
| | | sum([v for k, v in hotwords_related_dict[rec_name].items()]), |
| | | hotwords_related_dict[rec_name]['tp'] / ( |
| | | hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] |
| | | ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0 |
| | | )) |
| | | |
| | | # tp: 热词在label里,同时在rec里 |
| | | # tn: 热词不在label里,同时不在rec里 |
| | | # fp: 热词不在label里,但是在rec里 |
| | | # fn: 热词在label里,但是不在rec里 |
| | | if not verbose: |
| | | print() |
| | | print() |
| | | |
| | | |
| | | if __name__ == "__main__": |
| | | args = get_args() |
| | | |
| | | # print("") |
| | | print(args) |
| | | main(args) |
| | | |