From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
examples/industrial_data_pretraining/lcbnet/compute_wer_details.py | 221 +++++++++++++++++++++++++++++++-----------------------
1 files changed, 127 insertions(+), 94 deletions(-)
diff --git a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
index e72d871..f672cc8 100755
--- a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
+++ b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
@@ -9,6 +9,7 @@
from tqdm import tqdm
import os
import pdb
+
remove_tag = False
spacelist = [" ", "\t", "\r", "\n"]
puncts = [
@@ -51,9 +52,9 @@
def get_wer(self):
assert self.ref_words != 0
errors = (
- self.errors[Code.substitution]
- + self.errors[Code.insertion]
- + self.errors[Code.deletion]
+ self.errors[Code.substitution]
+ + self.errors[Code.insertion]
+ + self.errors[Code.deletion]
)
return 100.0 * errors / self.ref_words
@@ -299,30 +300,30 @@
for i in reversed(range(len(unicode_names))):
if unicode_names[i].startswith("DIGIT"): # 1
unicode_names[i] = "Number" # 'DIGIT'
- elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
- i
- ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
+ elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[i].startswith(
+ "CJK COMPATIBILITY IDEOGRAPH"
+ ):
# 鏄� / 铯�
unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH'
- elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
- i
- ].startswith("LATIN SMALL LETTER"):
+ elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[i].startswith(
+ "LATIN SMALL LETTER"
+ ):
# A / a
unicode_names[i] = "English" # 'LATIN LETTER'
elif unicode_names[i].startswith("HIRAGANA LETTER"): # 銇� 銇� 銈�
unicode_names[i] = "Japanese" # 'GANA LETTER'
elif (
- unicode_names[i].startswith("AMPERSAND")
- or unicode_names[i].startswith("APOSTROPHE")
- or unicode_names[i].startswith("COMMERCIAL AT")
- or unicode_names[i].startswith("DEGREE CELSIUS")
- or unicode_names[i].startswith("EQUALS SIGN")
- or unicode_names[i].startswith("FULL STOP")
- or unicode_names[i].startswith("HYPHEN-MINUS")
- or unicode_names[i].startswith("LOW LINE")
- or unicode_names[i].startswith("NUMBER SIGN")
- or unicode_names[i].startswith("PLUS SIGN")
- or unicode_names[i].startswith("SEMICOLON")
+ unicode_names[i].startswith("AMPERSAND")
+ or unicode_names[i].startswith("APOSTROPHE")
+ or unicode_names[i].startswith("COMMERCIAL AT")
+ or unicode_names[i].startswith("DEGREE CELSIUS")
+ or unicode_names[i].startswith("EQUALS SIGN")
+ or unicode_names[i].startswith("FULL STOP")
+ or unicode_names[i].startswith("HYPHEN-MINUS")
+ or unicode_names[i].startswith("LOW LINE")
+ or unicode_names[i].startswith("NUMBER SIGN")
+ or unicode_names[i].startswith("PLUS SIGN")
+ or unicode_names[i].startswith("SEMICOLON")
):
# & / ' / @ / 鈩� / = / . / - / _ / # / + / ;
del unicode_names[i]
@@ -411,11 +412,13 @@
if len(array) == 0:
continue
fid = array[0]
- rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+ rec_sets[rec_names[i]][fid] = normalize(
+ array[1:], ignore_words, case_sensitive, split
+ )
calculators_dict[rec_names[i]] = Calculator()
ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
- hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
+ hotwords_related_dict[rec_names[i]] = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
# tp: 鐑瘝鍦╨abel閲岋紝鍚屾椂鍦╮ec閲�
# tn: 鐑瘝涓嶅湪label閲岋紝鍚屾椂涓嶅湪rec閲�
# fp: 鐑瘝涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
@@ -431,21 +434,22 @@
_file_total_len = int(pipe.read().strip())
# compute error rate on the interaction of reference file and hyp file
- for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
+ for line in tqdm(open(ref_file, "r", encoding="utf-8"), total=_file_total_len):
if tochar:
array = characterize(line)
else:
- array = line.rstrip('\n').split()
- if len(array) == 0: continue
+ array = line.rstrip("\n").split()
+ if len(array) == 0:
+ continue
fid = array[0]
lab = normalize(array[1:], ignore_words, case_sensitive, split)
if verbose:
- print('\nutt: %s' % fid)
+ print("\nutt: %s" % fid)
ocr_text = ref_ocr_dict[fid]
ocr_set = set(ocr_text)
- print('ocr: {}'.format(" ".join(ocr_text)))
+ print("ocr: {}".format(" ".join(ocr_text)))
list_match = [] # 鎸噇abel閲岄潰鍦╫cr閲岄潰鐨勫唴瀹�
list_not_mathch = []
tmp_error = 0
@@ -458,7 +462,7 @@
else:
tmp_match += 1
list_match.append(lab[index])
- print('label in ocr: {}'.format(" ".join(list_match)))
+ print("label in ocr: {}".format(" ".join(list_match)))
# for each reco file
base_wrong_ocr_wer = None
@@ -482,33 +486,44 @@
result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
if verbose:
- if result['all'] != 0:
- wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+ if result["all"] != 0:
+ wer = (
+ float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"]
+ )
else:
wer = 0.0
- print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
- print('N=%d C=%d S=%d D=%d I=%d' %
- (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
-
+ print("WER(%s): %4.2f %%" % (rec_name, wer), end=" ")
+ print(
+ "N=%d C=%d S=%d D=%d I=%d"
+ % (result["all"], result["cor"], result["sub"], result["del"], result["ins"])
+ )
# print(result['rec'])
wrong_rec_but_in_ocr = []
- for idx in range(len(result['lab'])):
- if result['lab'][idx] != "":
- if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
- if result['lab'][idx] in list_match:
- wrong_rec_but_in_ocr.append(result['lab'][idx])
+ for idx in range(len(result["lab"])):
+ if result["lab"][idx] != "":
+ if result["lab"][idx] != result["rec"][idx].replace("<BIAS>", ""):
+ if result["lab"][idx] in list_match:
+ wrong_rec_but_in_ocr.append(result["lab"][idx])
wrong_rec_but_in_ocr_dict[rec_name] += 1
- print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
+ print("wrong_rec_but_in_ocr: {}".format(" ".join(wrong_rec_but_in_ocr)))
if rec_name == "base":
base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
if "ocr" in rec_name or "hot" in rec_name:
ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
- print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
+ print(
+ "{} {} helps, {} -> {}".format(
+ fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer
+ )
+ )
elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
- print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
+ print(
+ "{} {} hurts, {} -> {}".format(
+ fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer
+ )
+ )
# recall = 0
# false_alarm = 0
@@ -537,11 +552,11 @@
# if badhotword == word:
# count += 1
if count == 0:
- hotwords_related_dict[rec_name]['tn'] += 1
+ hotwords_related_dict[rec_name]["tn"] += 1
_tn += 1
# fp: 0
else:
- hotwords_related_dict[rec_name]['fp'] += count
+ hotwords_related_dict[rec_name]["fp"] += count
_fp += count
# tn: 0
# if badhotword in _rec_list:
@@ -553,23 +568,30 @@
rec_count = len([word for word in _rec_list if hotword == word])
# print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
if rec_count == true_count:
- hotwords_related_dict[rec_name]['tp'] += true_count
+ hotwords_related_dict[rec_name]["tp"] += true_count
_tp += true_count
elif rec_count > true_count:
- hotwords_related_dict[rec_name]['tp'] += true_count
+ hotwords_related_dict[rec_name]["tp"] += true_count
# fp: 涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
- hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
+ hotwords_related_dict[rec_name]["fp"] += rec_count - true_count
_tp += true_count
_fp += rec_count - true_count
else:
- hotwords_related_dict[rec_name]['tp'] += rec_count
+ hotwords_related_dict[rec_name]["tp"] += rec_count
# fn: 鐑瘝鍦╨abel閲岋紝浣嗘槸涓嶅湪rec閲�
- hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
+ hotwords_related_dict[rec_name]["fn"] += true_count - rec_count
_tp += rec_count
_fn += true_count - rec_count
- print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
- _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
- ))
+ print(
+ "hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
+ _tp,
+ _tn,
+ _fp,
+ _fn,
+ sum([_tp, _tn, _fp, _fn]),
+ _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0,
+ )
+ )
# if hotword in _rec_list:
# hotwords_related_dict[rec_name]['tp'] += 1
@@ -612,77 +634,89 @@
ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
space = {}
- space['lab'] = []
- space['rec'] = []
- for idx in range(len(result['lab'])):
- len_lab = width(result['lab'][idx])
- len_rec = width(result['rec'][idx])
+ space["lab"] = []
+ space["rec"] = []
+ for idx in range(len(result["lab"])):
+ len_lab = width(result["lab"][idx])
+ len_rec = width(result["rec"][idx])
length = max(len_lab, len_rec)
- space['lab'].append(length - len_lab)
- space['rec'].append(length - len_rec)
- upper_lab = len(result['lab'])
- upper_rec = len(result['rec'])
+ space["lab"].append(length - len_lab)
+ space["rec"].append(length - len_rec)
+ upper_lab = len(result["lab"])
+ upper_rec = len(result["rec"])
lab1, rec1 = 0, 0
while lab1 < upper_lab or rec1 < upper_rec:
if verbose > 1:
- print('lab(%s):' % fid.encode('utf-8'), end=' ')
+ print("lab(%s):" % fid.encode("utf-8"), end=" ")
else:
- print('lab:', end=' ')
+ print("lab:", end=" ")
lab2 = min(upper_lab, lab1 + max_words_per_line)
for idx in range(lab1, lab2):
- token = result['lab'][idx]
- print('{token}'.format(token=token), end='')
- for n in range(space['lab'][idx]):
- print(padding_symbol, end='')
- print(' ', end='')
+ token = result["lab"][idx]
+ print("{token}".format(token=token), end="")
+ for n in range(space["lab"][idx]):
+ print(padding_symbol, end="")
+ print(" ", end="")
print()
if verbose > 1:
- print('rec(%s):' % fid.encode('utf-8'), end=' ')
+ print("rec(%s):" % fid.encode("utf-8"), end=" ")
else:
- print('rec:', end=' ')
+ print("rec:", end=" ")
rec2 = min(upper_rec, rec1 + max_words_per_line)
for idx in range(rec1, rec2):
- token = result['rec'][idx]
- print('{token}'.format(token=token), end='')
- for n in range(space['rec'][idx]):
- print(padding_symbol, end='')
- print(' ', end='')
+ token = result["rec"][idx]
+ print("{token}".format(token=token), end="")
+ for n in range(space["rec"][idx]):
+ print(padding_symbol, end="")
+ print(" ", end="")
print()
# print('\n', end='\n')
lab1 = lab2
rec1 = rec2
- print('\n', end='\n')
+ print("\n", end="\n")
# break
if verbose:
- print('===========================================================================')
+ print("===========================================================================")
print()
print(wrong_rec_but_in_ocr_dict)
for rec_name in rec_names:
result = calculators_dict[rec_name].overall()
- if result['all'] != 0:
- wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+ if result["all"] != 0:
+ wer = float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"]
else:
wer = 0.0
- print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
- print('N=%d C=%d S=%d D=%d I=%d' %
- (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+ print("{} Overall -> {:4.2f} %".format(rec_name, wer), end=" ")
+ print(
+ "N=%d C=%d S=%d D=%d I=%d"
+ % (result["all"], result["cor"], result["sub"], result["del"], result["ins"])
+ )
print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
- print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
- hotwords_related_dict[rec_name]['tp'],
- hotwords_related_dict[rec_name]['tn'],
- hotwords_related_dict[rec_name]['fp'],
- hotwords_related_dict[rec_name]['fn'],
- sum([v for k, v in hotwords_related_dict[rec_name].items()]),
- hotwords_related_dict[rec_name]['tp'] / (
- hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
- ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
- ))
+ print(
+ "hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
+ hotwords_related_dict[rec_name]["tp"],
+ hotwords_related_dict[rec_name]["tn"],
+ hotwords_related_dict[rec_name]["fp"],
+ hotwords_related_dict[rec_name]["fn"],
+ sum([v for k, v in hotwords_related_dict[rec_name].items()]),
+ (
+ hotwords_related_dict[rec_name]["tp"]
+ / (
+ hotwords_related_dict[rec_name]["tp"]
+ + hotwords_related_dict[rec_name]["fn"]
+ )
+ * 100
+ if hotwords_related_dict[rec_name]["tp"] + hotwords_related_dict[rec_name]["fn"]
+ != 0
+ else 0
+ ),
+ )
+ )
# tp: 鐑瘝鍦╨abel閲岋紝鍚屾椂鍦╮ec閲�
# tn: 鐑瘝涓嶅湪label閲岋紝鍚屾椂涓嶅湪rec閲�
@@ -695,8 +729,7 @@
if __name__ == "__main__":
args = get_args()
-
+
# print("")
print(args)
main(args)
-
--
Gitblit v1.9.1