From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 examples/industrial_data_pretraining/lcbnet/compute_wer_details.py |  221 +++++++++++++++++++++++++++++++-----------------------
 1 files changed, 127 insertions(+), 94 deletions(-)

diff --git a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
index e72d871..f672cc8 100755
--- a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
+++ b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
@@ -9,6 +9,7 @@
 from tqdm import tqdm
 import os
 import pdb
+
 remove_tag = False
 spacelist = [" ", "\t", "\r", "\n"]
 puncts = [
@@ -51,9 +52,9 @@
     def get_wer(self):
         assert self.ref_words != 0
         errors = (
-                self.errors[Code.substitution]
-                + self.errors[Code.insertion]
-                + self.errors[Code.deletion]
+            self.errors[Code.substitution]
+            + self.errors[Code.insertion]
+            + self.errors[Code.deletion]
         )
         return 100.0 * errors / self.ref_words
 
@@ -299,30 +300,30 @@
     for i in reversed(range(len(unicode_names))):
         if unicode_names[i].startswith("DIGIT"):  # 1
             unicode_names[i] = "Number"  # 'DIGIT'
-        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
-            i
-        ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
+        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[i].startswith(
+            "CJK COMPATIBILITY IDEOGRAPH"
+        ):
             # 鏄� / 铯�
             unicode_names[i] = "Mandarin"  # 'CJK IDEOGRAPH'
-        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
-            i
-        ].startswith("LATIN SMALL LETTER"):
+        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[i].startswith(
+            "LATIN SMALL LETTER"
+        ):
             # A / a
             unicode_names[i] = "English"  # 'LATIN LETTER'
         elif unicode_names[i].startswith("HIRAGANA LETTER"):  # 銇� 銇� 銈�
             unicode_names[i] = "Japanese"  # 'GANA LETTER'
         elif (
-                unicode_names[i].startswith("AMPERSAND")
-                or unicode_names[i].startswith("APOSTROPHE")
-                or unicode_names[i].startswith("COMMERCIAL AT")
-                or unicode_names[i].startswith("DEGREE CELSIUS")
-                or unicode_names[i].startswith("EQUALS SIGN")
-                or unicode_names[i].startswith("FULL STOP")
-                or unicode_names[i].startswith("HYPHEN-MINUS")
-                or unicode_names[i].startswith("LOW LINE")
-                or unicode_names[i].startswith("NUMBER SIGN")
-                or unicode_names[i].startswith("PLUS SIGN")
-                or unicode_names[i].startswith("SEMICOLON")
+            unicode_names[i].startswith("AMPERSAND")
+            or unicode_names[i].startswith("APOSTROPHE")
+            or unicode_names[i].startswith("COMMERCIAL AT")
+            or unicode_names[i].startswith("DEGREE CELSIUS")
+            or unicode_names[i].startswith("EQUALS SIGN")
+            or unicode_names[i].startswith("FULL STOP")
+            or unicode_names[i].startswith("HYPHEN-MINUS")
+            or unicode_names[i].startswith("LOW LINE")
+            or unicode_names[i].startswith("NUMBER SIGN")
+            or unicode_names[i].startswith("PLUS SIGN")
+            or unicode_names[i].startswith("SEMICOLON")
         ):
             # & / ' / @ / 鈩� / = / . / - / _ / # / + / ;
             del unicode_names[i]
@@ -411,11 +412,13 @@
                 if len(array) == 0:
                     continue
                 fid = array[0]
-                rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+                rec_sets[rec_names[i]][fid] = normalize(
+                    array[1:], ignore_words, case_sensitive, split
+                )
 
         calculators_dict[rec_names[i]] = Calculator()
         ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
-        hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
+        hotwords_related_dict[rec_names[i]] = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}
         # tp: 鐑瘝鍦╨abel閲岋紝鍚屾椂鍦╮ec閲�
         # tn: 鐑瘝涓嶅湪label閲岋紝鍚屾椂涓嶅湪rec閲�
         # fp: 鐑瘝涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
@@ -431,21 +434,22 @@
         _file_total_len = int(pipe.read().strip())
 
     # compute error rate on the interaction of reference file and hyp file
-    for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
+    for line in tqdm(open(ref_file, "r", encoding="utf-8"), total=_file_total_len):
         if tochar:
             array = characterize(line)
         else:
-            array = line.rstrip('\n').split()
-        if len(array) == 0: continue
+            array = line.rstrip("\n").split()
+        if len(array) == 0:
+            continue
         fid = array[0]
         lab = normalize(array[1:], ignore_words, case_sensitive, split)
 
         if verbose:
-            print('\nutt: %s' % fid)
+            print("\nutt: %s" % fid)
 
         ocr_text = ref_ocr_dict[fid]
         ocr_set = set(ocr_text)
-        print('ocr: {}'.format(" ".join(ocr_text)))
+        print("ocr: {}".format(" ".join(ocr_text)))
         list_match = []  # 鎸噇abel閲岄潰鍦╫cr閲岄潰鐨勫唴瀹�
         list_not_mathch = []
         tmp_error = 0
@@ -458,7 +462,7 @@
             else:
                 tmp_match += 1
                 list_match.append(lab[index])
-        print('label in ocr: {}'.format(" ".join(list_match)))
+        print("label in ocr: {}".format(" ".join(list_match)))
 
         # for each reco file
         base_wrong_ocr_wer = None
@@ -482,33 +486,44 @@
 
             result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
             if verbose:
-                if result['all'] != 0:
-                    wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+                if result["all"] != 0:
+                    wer = (
+                        float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"]
+                    )
                 else:
                     wer = 0.0
-            print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
-            print('N=%d C=%d S=%d D=%d I=%d' %
-                  (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
-
+            print("WER(%s): %4.2f %%" % (rec_name, wer), end=" ")
+            print(
+                "N=%d C=%d S=%d D=%d I=%d"
+                % (result["all"], result["cor"], result["sub"], result["del"], result["ins"])
+            )
 
             # print(result['rec'])
             wrong_rec_but_in_ocr = []
-            for idx in range(len(result['lab'])):
-                if result['lab'][idx] != "":
-                    if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
-                        if result['lab'][idx] in list_match:
-                            wrong_rec_but_in_ocr.append(result['lab'][idx])
+            for idx in range(len(result["lab"])):
+                if result["lab"][idx] != "":
+                    if result["lab"][idx] != result["rec"][idx].replace("<BIAS>", ""):
+                        if result["lab"][idx] in list_match:
+                            wrong_rec_but_in_ocr.append(result["lab"][idx])
                             wrong_rec_but_in_ocr_dict[rec_name] += 1
-            print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
+            print("wrong_rec_but_in_ocr: {}".format(" ".join(wrong_rec_but_in_ocr)))
 
             if rec_name == "base":
                 base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
             if "ocr" in rec_name or "hot" in rec_name:
                 ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
                 if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
-                    print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
+                    print(
+                        "{} {} helps, {} -> {}".format(
+                            fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer
+                        )
+                    )
                 elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
-                    print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
+                    print(
+                        "{} {} hurts, {} -> {}".format(
+                            fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer
+                        )
+                    )
 
             # recall = 0
             # false_alarm = 0
@@ -537,11 +552,11 @@
                 #     if badhotword == word:
                 #         count += 1
                 if count == 0:
-                    hotwords_related_dict[rec_name]['tn'] += 1
+                    hotwords_related_dict[rec_name]["tn"] += 1
                     _tn += 1
                     # fp: 0
                 else:
-                    hotwords_related_dict[rec_name]['fp'] += count
+                    hotwords_related_dict[rec_name]["fp"] += count
                     _fp += count
                     # tn: 0
                 # if badhotword in _rec_list:
@@ -553,23 +568,30 @@
                 rec_count = len([word for word in _rec_list if hotword == word])
                 # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
                 if rec_count == true_count:
-                    hotwords_related_dict[rec_name]['tp'] += true_count
+                    hotwords_related_dict[rec_name]["tp"] += true_count
                     _tp += true_count
                 elif rec_count > true_count:
-                    hotwords_related_dict[rec_name]['tp'] += true_count
+                    hotwords_related_dict[rec_name]["tp"] += true_count
                     # fp: 涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
-                    hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
+                    hotwords_related_dict[rec_name]["fp"] += rec_count - true_count
                     _tp += true_count
                     _fp += rec_count - true_count
                 else:
-                    hotwords_related_dict[rec_name]['tp'] += rec_count
+                    hotwords_related_dict[rec_name]["tp"] += rec_count
                     # fn: 鐑瘝鍦╨abel閲岋紝浣嗘槸涓嶅湪rec閲�
-                    hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
+                    hotwords_related_dict[rec_name]["fn"] += true_count - rec_count
                     _tp += rec_count
                     _fn += true_count - rec_count
-            print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
-                _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
-            ))
+            print(
+                "hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
+                    _tp,
+                    _tn,
+                    _fp,
+                    _fn,
+                    sum([_tp, _tn, _fp, _fn]),
+                    _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0,
+                )
+            )
 
             # if hotword in _rec_list:
             #     hotwords_related_dict[rec_name]['tp'] += 1
@@ -612,77 +634,89 @@
                         ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
 
             space = {}
-            space['lab'] = []
-            space['rec'] = []
-            for idx in range(len(result['lab'])):
-                len_lab = width(result['lab'][idx])
-                len_rec = width(result['rec'][idx])
+            space["lab"] = []
+            space["rec"] = []
+            for idx in range(len(result["lab"])):
+                len_lab = width(result["lab"][idx])
+                len_rec = width(result["rec"][idx])
                 length = max(len_lab, len_rec)
-                space['lab'].append(length - len_lab)
-                space['rec'].append(length - len_rec)
-            upper_lab = len(result['lab'])
-            upper_rec = len(result['rec'])
+                space["lab"].append(length - len_lab)
+                space["rec"].append(length - len_rec)
+            upper_lab = len(result["lab"])
+            upper_rec = len(result["rec"])
             lab1, rec1 = 0, 0
             while lab1 < upper_lab or rec1 < upper_rec:
                 if verbose > 1:
-                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
+                    print("lab(%s):" % fid.encode("utf-8"), end=" ")
                 else:
-                    print('lab:', end=' ')
+                    print("lab:", end=" ")
                 lab2 = min(upper_lab, lab1 + max_words_per_line)
                 for idx in range(lab1, lab2):
-                    token = result['lab'][idx]
-                    print('{token}'.format(token=token), end='')
-                    for n in range(space['lab'][idx]):
-                        print(padding_symbol, end='')
-                    print(' ', end='')
+                    token = result["lab"][idx]
+                    print("{token}".format(token=token), end="")
+                    for n in range(space["lab"][idx]):
+                        print(padding_symbol, end="")
+                    print(" ", end="")
                 print()
                 if verbose > 1:
-                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
+                    print("rec(%s):" % fid.encode("utf-8"), end=" ")
                 else:
-                    print('rec:', end=' ')
+                    print("rec:", end=" ")
 
                 rec2 = min(upper_rec, rec1 + max_words_per_line)
                 for idx in range(rec1, rec2):
-                    token = result['rec'][idx]
-                    print('{token}'.format(token=token), end='')
-                    for n in range(space['rec'][idx]):
-                        print(padding_symbol, end='')
-                    print(' ', end='')
+                    token = result["rec"][idx]
+                    print("{token}".format(token=token), end="")
+                    for n in range(space["rec"][idx]):
+                        print(padding_symbol, end="")
+                    print(" ", end="")
                 print()
                 # print('\n', end='\n')
                 lab1 = lab2
                 rec1 = rec2
-        print('\n', end='\n')
+        print("\n", end="\n")
         # break
     if verbose:
-        print('===========================================================================')
+        print("===========================================================================")
         print()
 
     print(wrong_rec_but_in_ocr_dict)
     for rec_name in rec_names:
         result = calculators_dict[rec_name].overall()
 
-        if result['all'] != 0:
-            wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+        if result["all"] != 0:
+            wer = float(result["ins"] + result["sub"] + result["del"]) * 100.0 / result["all"]
         else:
             wer = 0.0
-        print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
-        print('N=%d C=%d S=%d D=%d I=%d' %
-              (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+        print("{} Overall -> {:4.2f} %".format(rec_name, wer), end=" ")
+        print(
+            "N=%d C=%d S=%d D=%d I=%d"
+            % (result["all"], result["cor"], result["sub"], result["del"], result["ins"])
+        )
         print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
         print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
         print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
 
-        print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
-            hotwords_related_dict[rec_name]['tp'],
-            hotwords_related_dict[rec_name]['tn'],
-            hotwords_related_dict[rec_name]['fp'],
-            hotwords_related_dict[rec_name]['fn'],
-            sum([v for k, v in hotwords_related_dict[rec_name].items()]),
-            hotwords_related_dict[rec_name]['tp'] / (
-                    hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
-            ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
-        ))
+        print(
+            "hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
+                hotwords_related_dict[rec_name]["tp"],
+                hotwords_related_dict[rec_name]["tn"],
+                hotwords_related_dict[rec_name]["fp"],
+                hotwords_related_dict[rec_name]["fn"],
+                sum([v for k, v in hotwords_related_dict[rec_name].items()]),
+                (
+                    hotwords_related_dict[rec_name]["tp"]
+                    / (
+                        hotwords_related_dict[rec_name]["tp"]
+                        + hotwords_related_dict[rec_name]["fn"]
+                    )
+                    * 100
+                    if hotwords_related_dict[rec_name]["tp"] + hotwords_related_dict[rec_name]["fn"]
+                    != 0
+                    else 0
+                ),
+            )
+        )
 
         # tp: 鐑瘝鍦╨abel閲岋紝鍚屾椂鍦╮ec閲�
         # tn: 鐑瘝涓嶅湪label閲岋紝鍚屾椂涓嶅湪rec閲�
@@ -695,8 +729,7 @@
 
 if __name__ == "__main__":
     args = get_args()
-    
+
     # print("")
     print(args)
     main(args)
-

--
Gitblit v1.9.1