| | |
| | | from tqdm import tqdm |
| | | import os |
| | | import pdb |
| | | |
| | | remove_tag = False |
| | | spacelist = [" ", "\t", "\r", "\n"] |
| | | puncts = [ |
| | |
def get_wer(self):
    """Return the word error rate as a percentage of the reference words.

    The substitution, insertion and deletion counts stored in
    ``self.errors`` (indexed by ``Code`` members) are added together and
    divided by ``self.ref_words``.

    Returns:
        ``100.0 * (S + I + D) / self.ref_words``.

    Raises:
        AssertionError: if ``self.ref_words`` is 0, since there is nothing
            to normalise the error count by.
    """
    assert self.ref_words != 0
    # Each error category contributes exactly once to the total.
    errors = (
        self.errors[Code.substitution]
        + self.errors[Code.insertion]
        + self.errors[Code.deletion]
    )
    return 100.0 * errors / self.ref_words
| | | |
| | |
| | | for i in reversed(range(len(unicode_names))): |
| | | if unicode_names[i].startswith("DIGIT"): # 1 |
| | | unicode_names[i] = "Number" # 'DIGIT' |
| | | elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[ |
| | | i |
| | | ].startswith("CJK COMPATIBILITY IDEOGRAPH"): |
| | | elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[i].startswith( |
| | | "CJK COMPATIBILITY IDEOGRAPH" |
| | | ): |
| | | # 明 / 郎 |
| | | unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH' |
| | | elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[ |
| | | i |
| | | ].startswith("LATIN SMALL LETTER"): |
| | | elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[i].startswith( |
| | | "LATIN SMALL LETTER" |
| | | ): |
| | | # A / a |
| | | unicode_names[i] = "English" # 'LATIN LETTER' |
| | | elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め |
| | | unicode_names[i] = "Japanese" # 'GANA LETTER' |
| | | elif ( |
| | | unicode_names[i].startswith("AMPERSAND") |
| | | or unicode_names[i].startswith("APOSTROPHE") |
| | | or unicode_names[i].startswith("COMMERCIAL AT") |
| | | or unicode_names[i].startswith("DEGREE CELSIUS") |
| | | or unicode_names[i].startswith("EQUALS SIGN") |
| | | or unicode_names[i].startswith("FULL STOP") |
| | | or unicode_names[i].startswith("HYPHEN-MINUS") |
| | | or unicode_names[i].startswith("LOW LINE") |
| | | or unicode_names[i].startswith("NUMBER SIGN") |
| | | or unicode_names[i].startswith("PLUS SIGN") |
| | | or unicode_names[i].startswith("SEMICOLON") |
| | | unicode_names[i].startswith("AMPERSAND") |
| | | or unicode_names[i].startswith("APOSTROPHE") |
| | | or unicode_names[i].startswith("COMMERCIAL AT") |
| | | or unicode_names[i].startswith("DEGREE CELSIUS") |
| | | or unicode_names[i].startswith("EQUALS SIGN") |
| | | or unicode_names[i].startswith("FULL STOP") |
| | | or unicode_names[i].startswith("HYPHEN-MINUS") |
| | | or unicode_names[i].startswith("LOW LINE") |
| | | or unicode_names[i].startswith("NUMBER SIGN") |
| | | or unicode_names[i].startswith("PLUS SIGN") |
| | | or unicode_names[i].startswith("SEMICOLON") |
| | | ): |
| | | # & / ' / @ / ℃ / = / . / - / _ / # / + / ; |
| | | del unicode_names[i] |
| | |
| | | if len(array) == 0: |
| | | continue |
| | | fid = array[0] |
| | | rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split) |
| | | rec_sets[rec_names[i]][fid] = normalize( |
| | | array[1:], ignore_words, case_sensitive, split |
| | | ) |
| | | |
| | | calculators_dict[rec_names[i]] = Calculator() |
| | | ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()} |
| | | hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0} |
| | | hotwords_related_dict[rec_names[i]] = {"tp": 0, "tn": 0, "fp": 0, "fn": 0} |
| | | # tp: 热词在label里,同时在rec里 |
| | | # tn: 热词不在label里,同时不在rec里 |
| | | # fp: 热词不在label里,但是在rec里 |
| | |
| | | _file_total_len = int(pipe.read().strip()) |
| | | |
| | | # compute error rate on the interaction of reference file and hyp file |
| | | for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len): |
| | | for line in tqdm(open(ref_file, "r", encoding="utf-8"), total=_file_total_len): |
| | | if tochar: |
| | | array = characterize(line) |
| | | else: |
| | | array = line.rstrip('\n').split() |
| | | if len(array) == 0: continue |
| | | array = line.rstrip("\n").split() |
| | | if len(array) == 0: |
| | | continue |
| | | fid = array[0] |
| | | lab = normalize(array[1:], ignore_words, case_sensitive, split) |
| | | |
| | | if verbose: |
| | | print('\nutt: %s' % fid) |
| | | print("\nutt: %s" % fid) |
| | | |
| | | ocr_text = ref_ocr_dict[fid] |
| | | ocr_set = set(ocr_text) |
| | | print('ocr: {}'.format(" ".join(ocr_text))) |
| | | print("ocr: {}".format(" ".join(ocr_text))) |
| | | list_match = [] # 指label里面在ocr里面的内容 |
| | | list_not_mathch = [] |
| | | tmp_error = 0 |
| | |
| | | else: |
| | | tmp_match += 1 |
| | | list_match.append(lab[index]) |
| | | print('label in ocr: {}'.format(" ".join(list_match))) |
| | | print("label in ocr: {}".format(" ".join(list_match))) |
| | | |
| | | # for each reco file |
| | | base_wrong_ocr_wer = None |
| | |
| | | |
| | | result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy()) |
| | | if verbose: |
| | | if result['all'] != 0: |
| | | wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] |
| | | if result["all"] != 0: |
| | | wer = ( |
| | | float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"] |
| | | ) |
| | | else: |
| | | wer = 0.0 |
| | | print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ') |
| | | print('N=%d C=%d S=%d D=%d I=%d' % |
| | | (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) |
| | | |
| | | print("WER(%s): %4.2f %%" % (rec_name, wer), end=" ") |
| | | print( |
| | | "N=%d C=%d S=%d D=%d I=%d" |
| | | % (result["all"], result["cor"], result["sub"], result["del"], result["ins"]) |
| | | ) |
| | | |
| | | # print(result['rec']) |
| | | wrong_rec_but_in_ocr = [] |
| | | for idx in range(len(result['lab'])): |
| | | if result['lab'][idx] != "": |
| | | if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""): |
| | | if result['lab'][idx] in list_match: |
| | | wrong_rec_but_in_ocr.append(result['lab'][idx]) |
| | | for idx in range(len(result["lab"])): |
| | | if result["lab"][idx] != "": |
| | | if result["lab"][idx] != result["rec"][idx].replace("<BIAS>", ""): |
| | | if result["lab"][idx] in list_match: |
| | | wrong_rec_but_in_ocr.append(result["lab"][idx]) |
| | | wrong_rec_but_in_ocr_dict[rec_name] += 1 |
| | | print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr))) |
| | | print("wrong_rec_but_in_ocr: {}".format(" ".join(wrong_rec_but_in_ocr))) |
| | | |
| | | if rec_name == "base": |
| | | base_wrong_ocr_wer = len(wrong_rec_but_in_ocr) |
| | | if "ocr" in rec_name or "hot" in rec_name: |
| | | ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr) |
| | | if ocr_wrong_ocr_wer < base_wrong_ocr_wer: |
| | | print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer)) |
| | | print( |
| | | "{} {} helps, {} -> {}".format( |
| | | fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer |
| | | ) |
| | | ) |
| | | elif ocr_wrong_ocr_wer > base_wrong_ocr_wer: |
| | | print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer)) |
| | | print( |
| | | "{} {} hurts, {} -> {}".format( |
| | | fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer |
| | | ) |
| | | ) |
| | | |
| | | # recall = 0 |
| | | # false_alarm = 0 |
| | |
| | | # if badhotword == word: |
| | | # count += 1 |
| | | if count == 0: |
| | | hotwords_related_dict[rec_name]['tn'] += 1 |
| | | hotwords_related_dict[rec_name]["tn"] += 1 |
| | | _tn += 1 |
| | | # fp: 0 |
| | | else: |
| | | hotwords_related_dict[rec_name]['fp'] += count |
| | | hotwords_related_dict[rec_name]["fp"] += count |
| | | _fp += count |
| | | # tn: 0 |
| | | # if badhotword in _rec_list: |
| | |
| | | rec_count = len([word for word in _rec_list if hotword == word]) |
| | | # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}") |
| | | if rec_count == true_count: |
| | | hotwords_related_dict[rec_name]['tp'] += true_count |
| | | hotwords_related_dict[rec_name]["tp"] += true_count |
| | | _tp += true_count |
| | | elif rec_count > true_count: |
| | | hotwords_related_dict[rec_name]['tp'] += true_count |
| | | hotwords_related_dict[rec_name]["tp"] += true_count |
| | | # fp: 不在label里,但是在rec里 |
| | | hotwords_related_dict[rec_name]['fp'] += rec_count - true_count |
| | | hotwords_related_dict[rec_name]["fp"] += rec_count - true_count |
| | | _tp += true_count |
| | | _fp += rec_count - true_count |
| | | else: |
| | | hotwords_related_dict[rec_name]['tp'] += rec_count |
| | | hotwords_related_dict[rec_name]["tp"] += rec_count |
| | | # fn: 热词在label里,但是不在rec里 |
| | | hotwords_related_dict[rec_name]['fn'] += true_count - rec_count |
| | | hotwords_related_dict[rec_name]["fn"] += true_count - rec_count |
| | | _tp += rec_count |
| | | _fn += true_count - rec_count |
| | | print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format( |
| | | _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0 |
| | | )) |
| | | print( |
| | | "hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format( |
| | | _tp, |
| | | _tn, |
| | | _fp, |
| | | _fn, |
| | | sum([_tp, _tn, _fp, _fn]), |
| | | _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0, |
| | | ) |
| | | ) |
| | | |
| | | # if hotword in _rec_list: |
| | | # hotwords_related_dict[rec_name]['tp'] += 1 |
| | |
| | | ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1 |
| | | |
| | | space = {} |
| | | space['lab'] = [] |
| | | space['rec'] = [] |
| | | for idx in range(len(result['lab'])): |
| | | len_lab = width(result['lab'][idx]) |
| | | len_rec = width(result['rec'][idx]) |
| | | space["lab"] = [] |
| | | space["rec"] = [] |
| | | for idx in range(len(result["lab"])): |
| | | len_lab = width(result["lab"][idx]) |
| | | len_rec = width(result["rec"][idx]) |
| | | length = max(len_lab, len_rec) |
| | | space['lab'].append(length - len_lab) |
| | | space['rec'].append(length - len_rec) |
| | | upper_lab = len(result['lab']) |
| | | upper_rec = len(result['rec']) |
| | | space["lab"].append(length - len_lab) |
| | | space["rec"].append(length - len_rec) |
| | | upper_lab = len(result["lab"]) |
| | | upper_rec = len(result["rec"]) |
| | | lab1, rec1 = 0, 0 |
| | | while lab1 < upper_lab or rec1 < upper_rec: |
| | | if verbose > 1: |
| | | print('lab(%s):' % fid.encode('utf-8'), end=' ') |
| | | print("lab(%s):" % fid.encode("utf-8"), end=" ") |
| | | else: |
| | | print('lab:', end=' ') |
| | | print("lab:", end=" ") |
| | | lab2 = min(upper_lab, lab1 + max_words_per_line) |
| | | for idx in range(lab1, lab2): |
| | | token = result['lab'][idx] |
| | | print('{token}'.format(token=token), end='') |
| | | for n in range(space['lab'][idx]): |
| | | print(padding_symbol, end='') |
| | | print(' ', end='') |
| | | token = result["lab"][idx] |
| | | print("{token}".format(token=token), end="") |
| | | for n in range(space["lab"][idx]): |
| | | print(padding_symbol, end="") |
| | | print(" ", end="") |
| | | print() |
| | | if verbose > 1: |
| | | print('rec(%s):' % fid.encode('utf-8'), end=' ') |
| | | print("rec(%s):" % fid.encode("utf-8"), end=" ") |
| | | else: |
| | | print('rec:', end=' ') |
| | | print("rec:", end=" ") |
| | | |
| | | rec2 = min(upper_rec, rec1 + max_words_per_line) |
| | | for idx in range(rec1, rec2): |
| | | token = result['rec'][idx] |
| | | print('{token}'.format(token=token), end='') |
| | | for n in range(space['rec'][idx]): |
| | | print(padding_symbol, end='') |
| | | print(' ', end='') |
| | | token = result["rec"][idx] |
| | | print("{token}".format(token=token), end="") |
| | | for n in range(space["rec"][idx]): |
| | | print(padding_symbol, end="") |
| | | print(" ", end="") |
| | | print() |
| | | # print('\n', end='\n') |
| | | lab1 = lab2 |
| | | rec1 = rec2 |
| | | print('\n', end='\n') |
| | | print("\n", end="\n") |
| | | # break |
| | | if verbose: |
| | | print('===========================================================================') |
| | | print("===========================================================================") |
| | | print() |
| | | |
| | | print(wrong_rec_but_in_ocr_dict) |
| | | for rec_name in rec_names: |
| | | result = calculators_dict[rec_name].overall() |
| | | |
| | | if result['all'] != 0: |
| | | wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] |
| | | if result["all"] != 0: |
| | | wer = float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"] |
| | | else: |
| | | wer = 0.0 |
| | | print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ') |
| | | print('N=%d C=%d S=%d D=%d I=%d' % |
| | | (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) |
| | | print("{} Overall -> {:4.2f} %".format(rec_name, wer), end=" ") |
| | | print( |
| | | "N=%d C=%d S=%d D=%d I=%d" |
| | | % (result["all"], result["cor"], result["sub"], result["del"], result["ins"]) |
| | | ) |
| | | print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}") |
| | | print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}") |
| | | print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}") |
| | | |
| | | print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format( |
| | | hotwords_related_dict[rec_name]['tp'], |
| | | hotwords_related_dict[rec_name]['tn'], |
| | | hotwords_related_dict[rec_name]['fp'], |
| | | hotwords_related_dict[rec_name]['fn'], |
| | | sum([v for k, v in hotwords_related_dict[rec_name].items()]), |
| | | hotwords_related_dict[rec_name]['tp'] / ( |
| | | hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] |
| | | ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0 |
| | | )) |
| | | print( |
| | | "hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format( |
| | | hotwords_related_dict[rec_name]["tp"], |
| | | hotwords_related_dict[rec_name]["tn"], |
| | | hotwords_related_dict[rec_name]["fp"], |
| | | hotwords_related_dict[rec_name]["fn"], |
| | | sum([v for k, v in hotwords_related_dict[rec_name].items()]), |
| | | ( |
| | | hotwords_related_dict[rec_name]["tp"] |
| | | / ( |
| | | hotwords_related_dict[rec_name]["tp"] |
| | | + hotwords_related_dict[rec_name]["fn"] |
| | | ) |
| | | * 100 |
| | | if hotwords_related_dict[rec_name]["tp"] + hotwords_related_dict[rec_name]["fn"] |
| | | != 0 |
| | | else 0 |
| | | ), |
| | | ) |
| | | ) |
| | | |
| | | # tp: 热词在label里,同时在rec里 |
| | | # tn: 热词不在label里,同时不在rec里 |
| | |
| | | |
if __name__ == "__main__":
    # Script entry point: parse the command-line options, echo them so the
    # run's configuration is visible in the output, then hand them to main().
    args = get_args()
    print(args)
    main(args)
| | | |